In [1]:
import torch


class Atten(torch.nn.Module):
    """Multi-head self-attention built from four linear projections.

    The query, value and output projections carry a bias; the key projection
    does not. The defaults (``dim=768``, ``heads=12``) give 64-dimensional
    heads and a query scale of ``64 ** -0.5 == 0.125``.

    The sequence length is read from the input at call time, so the module
    is not tied to one particular number of positions.

    Args:
        dim: Size of the last (feature) axis of the input and output.
        heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim=768, heads=12):
        super().__init__()

        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by a positive heads ({heads})')

        # Plain Python attributes: they are not parameters or buffers, so the
        # state_dict keys stay q.*, k.*, v.*, out.* only.
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim**-0.5

        self.q = torch.nn.Linear(dim, dim, bias=True)
        self.k = torch.nn.Linear(dim, dim, bias=False)
        self.v = torch.nn.Linear(dim, dim, bias=True)
        self.out = torch.nn.Linear(dim, dim, bias=True)

    def forward(self, x):
        """Apply self-attention over the second-to-last axis of ``x``.

        Args:
            x: Tensor whose last two axes are ``[seq_len, dim]``; leading
                axes are flattened into a single batch axis.

        Returns:
            Tensor of shape ``[batch, seq_len, dim]``.
        """
        heads, head_dim = self.heads, self.head_dim
        seq_len = x.shape[-2]

        # Scale the queries (not the scores) before splitting into heads.
        q = self.q(x) * self.scale
        k = self.k(x)
        v = self.v(x)

        def split_heads(t):
            #[b, n, dim] -> [b, n, h, d] -> [b, h, n, d] -> [b*h, n, d]
            t = t.reshape(-1, seq_len, heads, head_dim).transpose(1, 2)
            return t.reshape(-1, seq_len, head_dim)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        #[b*h, n, d] * [b*h, d, n] -> [b*h, n, n]
        #[b*h, n, n] * [b*h, n, d] -> [b*h, n, d]
        atten = q.bmm(k.transpose(1, 2)).softmax(dim=-1).bmm(v)

        #[b*h, n, d] -> [b, h, n, d] -> [b, n, h, d] -> [b, n, dim]
        atten = atten.reshape(-1, heads, seq_len, head_dim).transpose(1, 2)
        atten = atten.reshape(-1, seq_len, heads * head_dim)

        return self.out(atten)


# Atten()(torch.randn(2, 1500, 768)).shape

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([2, 1500, 768])

In [2]:
class Layer(torch.nn.Module):
    """One encoder block: two residual branches, each starting with LayerNorm.

    ``s1`` is LayerNorm followed by ``Atten``; ``s2`` is LayerNorm followed by
    a 768 -> 3072 -> 768 feed-forward network with GELU in between. The
    position of every sub-module inside its ``Sequential`` is significant
    because ``load_encoder`` addresses them by index.
    """

    def __init__(self):
        super().__init__()

        # Attention branch: index 0 = norm, index 1 = attention.
        attention_branch = [torch.nn.LayerNorm(768), Atten()]
        self.s1 = torch.nn.Sequential(*attention_branch)

        # Feed-forward branch: index 0 = norm, 1 = expand, 2 = GELU, 3 = project.
        feed_forward_branch = [
            torch.nn.LayerNorm(768),
            torch.nn.Linear(768, 3072),
            torch.nn.GELU(),
            torch.nn.Linear(3072, 768),
        ]
        self.s2 = torch.nn.Sequential(*feed_forward_branch)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        after_attention = self.s1(x) + x
        after_feed_forward = self.s2(after_attention) + after_attention
        return after_feed_forward


# Layer()(torch.randn(2, 1500, 768)).shape

torch.Size([2, 1500, 768])

In [3]:
class Encoder(torch.nn.Module):
    """Convolutional stem, learned positions, and a stack of 12 ``Layer`` blocks.

    ``s1`` holds two Conv1d + GELU stages (80 -> 768 channels, the second with
    stride 2). ``embed`` is a 1500 x 768 table whose weight matrix is added
    directly as positional information. ``s2`` holds the 12 blocks at indices
    0-11 and a final LayerNorm at index 12.
    """

    def __init__(self):
        super().__init__()

        stem = (
            torch.nn.Conv1d(80, 768, kernel_size=3, stride=1, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(768, 768, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.s1 = torch.nn.Sequential(*stem)

        # Only the weight matrix is used; the table is never looked up by index.
        self.embed = torch.nn.Embedding(1500, 768)

        blocks = [Layer() for _ in range(12)]
        self.s2 = torch.nn.Sequential(*blocks, torch.nn.LayerNorm(768))

    def forward(self, x):
        """Encode ``x`` of shape [batch, 80, frames] to [batch, frames / 2, 768].

        The positional add broadcasts a [1500, 768] matrix, so the stem
        output must have 1500 positions (3000 input frames).
        """
        #[b, 80, 3000] -> [b, 768, 1500] -> [b, 1500, 768]
        features = self.s1(x).permute(0, 2, 1)
        positioned = features + self.embed.weight

        return self.s2(positioned)


# Encoder()(torch.randn(2, 80, 3000)).shape

torch.Size([2, 1500, 768])

In [4]:
def load_encoder(pretrained):
    """Build an ``Encoder`` and copy every weight from ``pretrained`` into it.

    Args:
        pretrained: A module exposing ``conv1``, ``conv2``,
            ``embed_positions``, ``layer_norm`` and an indexable ``layers``
            whose first 12 entries each expose ``self_attn`` (with
            ``q_proj``/``k_proj``/``v_proj``/``out_proj``),
            ``self_attn_layer_norm``, ``final_layer_norm``, ``fc1`` and
            ``fc2``. The commented-out example in this file passes the
            encoder of ``openai/whisper-small``.

    Returns:
        The freshly constructed ``Encoder`` with the copied parameters.
    """
    encoder = Encoder()

    # Convolutional stem and positional table.
    stem_pairs = (
        (encoder.s1[0], pretrained.conv1),
        (encoder.s1[2], pretrained.conv2),
        (encoder.embed, pretrained.embed_positions),
    )
    for target, source in stem_pairs:
        target.load_state_dict(source.state_dict())

    # Transformer blocks, matched one-to-one by position.
    for idx in range(12):
        block = encoder.s2[idx]
        reference = pretrained.layers[idx]
        attention = block.s1[1]
        reference_attention = reference.self_attn

        block_pairs = (
            (attention.q, reference_attention.q_proj),
            (attention.k, reference_attention.k_proj),
            (attention.v, reference_attention.v_proj),
            (attention.out, reference_attention.out_proj),
            (block.s1[0], reference.self_attn_layer_norm),
            (block.s2[0], reference.final_layer_norm),
            (block.s2[1], reference.fc1),
            (block.s2[3], reference.fc2),
        )
        for target, source in block_pairs:
            target.load_state_dict(source.state_dict())

    # Final LayerNorm sits after the 12 blocks in the Sequential.
    encoder.s2[12].load_state_dict(pretrained.layer_norm.state_dict())

    return encoder


# from transformers import WhisperForConditionalGeneration

# pretrained = WhisperForConditionalGeneration.from_pretrained(
#     'openai/whisper-small').model.encoder
# encoder = load_encoder(pretrained)

# x = torch.randn(2, 80, 3000)
# (encoder(x) == pretrained(x).last_hidden_state).all()

tensor(True)