# Robust Speech Recognition via Large-Scale Weak Supervision

- [whisper paper](https://arxiv.org/abs/2212.04356)
- [openai github](https://github.com/openai/whisper)

In [1]:
# tiny multilingual model
# !wget https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt
# tiny english only model
!wget https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt

--2023-01-22 16:02:52--  https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt
Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.253.40, 13.107.226.40, 2620:1ec:29:1::40, ...
Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.253.40|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75571315 (72M) [application/octet-stream]
Saving to: ‘tiny.en.pt’


2023-01-22 16:02:54 (33.5 MB/s) - ‘tiny.en.pt’ saved [75571315/75571315]



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [3]:
# Load the downloaded checkpoint and list every entry of its state dict,
# one per line, as "<tensor shape> <parameter name>".
data = torch.load('tiny.en.pt')
for layer, weights in data['model_state_dict'].items():
    print(weights.shape, layer)

torch.Size([448, 384]) decoder.positional_embedding
torch.Size([1500, 384]) encoder.positional_embedding
torch.Size([51864, 384]) decoder.token_embedding.weight
torch.Size([384]) decoder.blocks.0.mlp_ln.weight
torch.Size([384]) decoder.blocks.0.mlp_ln.bias
torch.Size([1536, 384]) decoder.blocks.0.mlp.0.weight
torch.Size([1536]) decoder.blocks.0.mlp.0.bias
torch.Size([384, 1536]) decoder.blocks.0.mlp.2.weight
torch.Size([384]) decoder.blocks.0.mlp.2.bias
torch.Size([384]) decoder.blocks.0.attn_ln.weight
torch.Size([384]) decoder.blocks.0.attn_ln.bias
torch.Size([384, 384]) decoder.blocks.0.attn.query.weight
torch.Size([384]) decoder.blocks.0.attn.query.bias
torch.Size([384, 384]) decoder.blocks.0.attn.key.weight
torch.Size([384, 384]) decoder.blocks.0.attn.value.weight
torch.Size([384]) decoder.blocks.0.attn.value.bias
torch.Size([384, 384]) decoder.blocks.0.attn.out.weight
torch.Size([384]) decoder.blocks.0.attn.out.bias
torch.Size([384]) decoder.blocks.0.cross_attn_ln.weight
torch.Siz

In [51]:
class MultiHeadAttn(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input. The key
    projection has no bias, matching the checkpoint (which lists
    ``attn.key.weight`` but no ``attn.key.bias``). A key bias ``b`` would add
    ``q . b`` to every logit of a given query row, and softmax is invariant
    to a per-row constant, so it would have no effect on the output.

    Args:
        emb_dim: feature size of the input and output; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.

    Attributes:
        scores: copy of the softmax attention weights from the most recent
            forward pass, shape (bs, num_heads, seq_len, seq_len); ``None``
            until ``forward`` has been called.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        # Creation order is kept (query, key, value, out) so the random
        # initialisation and the state_dict ordering are unchanged.
        self.query = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim, bias=False)
        self.value = nn.Linear(emb_dim, emb_dim)
        self.out = nn.Linear(emb_dim, emb_dim)
        self.scores = None

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (bs, seq_len, emb_dim).

        Returns a tensor of the same shape and stores the attention weights
        in ``self.scores``.
        """
        batch = x.shape[0]
        head_size = self.emb_dim // self.num_heads

        def split_heads(proj):
            # (bs, seq_len, emb_dim) -> (bs, num_heads, seq_len, head_size)
            heads = proj(x).reshape(batch, -1, self.num_heads, head_size)
            return heads.permute(0, 2, 1, 3)

        q = split_heads(self.query)
        k = split_heads(self.key)
        v = split_heads(self.value)

        # Scaled dot-product logits, normalised over the key axis.
        weights = q @ k.transpose(-2, -1) * (1/math.sqrt(head_size))
        weights = torch.softmax(weights, dim=-1)
        self.scores = weights.clone()

        # Merge heads: (bs, num_heads, seq_len, head_size) -> (bs, seq_len, emb_dim)
        merged = (weights @ v).permute(0, 2, 1, 3)
        merged = merged.reshape(batch, -1, self.emb_dim)
        return self.out(merged)

class MLPBlock(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    Args:
        emb_dim: size of the input and output feature dimension.
        hidden_dim: width of the hidden layer. Defaults to 1536, the value
            that used to be hard-coded (4 * 384 for the tiny model).
    """

    def __init__(self, emb_dim, hidden_dim=1536):
        super().__init__()
        self.ff1 = nn.Linear(emb_dim, hidden_dim)
        self.ff2 = nn.Linear(hidden_dim, emb_dim)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last dim must be emb_dim).

        Returns a tensor with the same shape as ``x``.
        """
        # F.dropout takes `p` and `training`. With training=False both calls
        # are identity ops, so the block is deterministic at inference time.
        x = F.dropout(F.gelu(self.ff1(x)), p=0.1, training=False)
        x = F.dropout(self.ff2(x), p=0.1, training=False)
        return x

class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block with residual connections.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    The residual (skip) paths are required for the pretrained weights to be
    meaningful: without them each sub-layer would discard its input.

    Args:
        in_ch: unused by this block; accepted so existing call sites that
            pass it positionally keep working.
        emb_dim: feature size of the block's input and output.
        num_heads: number of attention heads.
    """

    def __init__(self, in_ch, emb_dim, num_heads):
        super().__init__()
        self.attn = MultiHeadAttn(emb_dim, num_heads)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.mlp = MLPBlock(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (bs, seq_len, emb_dim).

        Returns a tensor of the same shape.
        """
        # Normalise, transform, then add back onto the untouched input.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class Encoder(nn.Module):
    """Audio encoder: two-layer Conv1d stem, positional embedding, a stack of
    ``EncoderBlock`` layers and a final LayerNorm.

    Args:
        in_ch: number of input channels fed to the first Conv1d.
        emb_dim: feature size used throughout the encoder.
        num_heads: number of attention heads in every block.
        num_encs: number of ``EncoderBlock`` layers.
        max_len: number of positions held by the ``pos_emb`` buffer. Defaults
            to 1500, the value that used to be hard-coded.
    """

    def __init__(self, in_ch, emb_dim, num_heads, num_encs, max_len=1500):
        super().__init__()
        self.in_ch = in_ch
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.num_encs = num_encs
        self.emb = nn.Sequential(*[
            nn.Conv1d(in_ch, emb_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(emb_dim, emb_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        ])
        # not implementing sin_pos_emb b/c
        # pretrained weights already provide pos_emb values.
        # Sized from emb_dim (not a literal 384) so the buffer always matches
        # the feature width of the rest of the model.
        self.register_buffer('pos_emb', tensor=torch.zeros((max_len, emb_dim)))
        self.model = nn.Sequential(*[
            EncoderBlock(in_ch, emb_dim, num_heads)
            for _ in range(num_encs)
        ])

        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Encode ``x`` of shape (bs, in_ch, frames).

        Returns a tensor of shape (bs, seq_len, emb_dim), where seq_len is the
        length left after the stride-2 convolution.

        Raises:
            ValueError: if seq_len does not match the number of positions in
                ``pos_emb``.
        """
        # Conv1d is channels-first: (bs, in_ch, frames) -> (bs, emb_dim, seq_len)
        x = self.emb(x)
        # Attention, LayerNorm and pos_emb are channels-last, so move the
        # feature axis to the end: (bs, seq_len, emb_dim).
        x = x.permute(0, 2, 1)
        if x.shape[1] != self.pos_emb.shape[0]:
            raise ValueError(
                f'sequence length after the conv stem is {x.shape[1]}, '
                f'but pos_emb holds {self.pos_emb.shape[0]} positions'
            )
        x = x + self.pos_emb
        x = self.model(x)
        x = self.norm(x)
        return x

    def load_pretrained_model(self, path):
        """Copy encoder weights from a Whisper checkpoint file into this module.

        Args:
            path: checkpoint file containing a ``'model_state_dict'`` entry
                with ``encoder.*`` keys.

        Raises:
            KeyError: if an expected key is missing from the checkpoint.
        """
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code; only load checkpoints from a trusted source.
        # map_location='cpu' lets the file load on machines without a GPU;
        # load_state_dict then copies values into this module's own tensors.
        pretrained = torch.load(path, map_location='cpu')['model_state_dict']
        params = self.state_dict()

        # local state_dict key -> checkpoint key
        key_map = {
            'emb.0.weight': 'encoder.conv1.weight',
            'emb.0.bias': 'encoder.conv1.bias',
            'emb.2.weight': 'encoder.conv2.weight',
            'emb.2.bias': 'encoder.conv2.bias',
            'pos_emb': 'encoder.positional_embedding',
            'norm.weight': 'encoder.ln_post.weight',
            # was mistakenly read from ln_post.weight before
            'norm.bias': 'encoder.ln_post.bias',
        }
        # Per-block suffixes: local name -> checkpoint name.
        block_map = {
            'attn.query.weight': 'attn.query.weight',
            'attn.query.bias': 'attn.query.bias',
            'attn.key.weight': 'attn.key.weight',
            'attn.value.weight': 'attn.value.weight',
            'attn.value.bias': 'attn.value.bias',
            'attn.out.weight': 'attn.out.weight',
            'attn.out.bias': 'attn.out.bias',
            'norm1.weight': 'attn_ln.weight',
            'norm1.bias': 'attn_ln.bias',
            'mlp.ff1.weight': 'mlp.0.weight',
            'mlp.ff1.bias': 'mlp.0.bias',
            'mlp.ff2.weight': 'mlp.2.weight',
            'mlp.ff2.bias': 'mlp.2.bias',
            'norm2.weight': 'mlp_ln.weight',
            'norm2.bias': 'mlp_ln.bias',
        }
        for i in range(self.num_encs):
            for local, remote in block_map.items():
                key_map[f'model.{i}.{local}'] = f'encoder.blocks.{i}.{remote}'

        for local, remote in key_map.items():
            params[local] = pretrained[remote]
        self.load_state_dict(params)

In [53]:
# Build the encoder with the sizes used for the tiny.en checkpoint in this
# notebook, then copy the pretrained encoder weights into it.
encoder = Encoder(in_ch=80, emb_dim=384, num_heads=6, num_encs=4)
encoder.load_pretrained_model(path='tiny.en.pt')