# Experimenting with fine-tuning OpenAI Whisper

- [whisper paper](https://arxiv.org/abs/2212.04356)
- [openai github](https://github.com/openai/whisper)

In [None]:
# Download a pretrained Whisper checkpoint with wget (notebook shell escape).
# Only the English-only download is active; it saves tiny.en.pt, the file
# the later torch.load('tiny.en.pt') call reads.
# tiny multilingual model (disabled)
# !wget https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt
# tiny English-only model
!wget https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [None]:
# Load the downloaded checkpoint and list every entry of its state dict
# as "<shape> <name>", one per line.
data = torch.load('tiny.en.pt')
for layer, weights in data['model_state_dict'].items():
    print(weights.shape, layer)

In [None]:
class MultiHeadAttn(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Takes a (bs, seq_len, emb_dim) tensor and returns a tensor of the same
    shape. The attention weights of the most recent forward pass are kept
    in ``self.scores`` with shape (bs, num_heads, seq_len, seq_len).

    Args:
        emb_dim: size of the embedding dimension; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        assert(emb_dim % num_heads == 0)
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.query = nn.Linear(emb_dim, emb_dim)
        # The key projection has no bias: a key bias would add the same
        # offset to every score within a query's row, and softmax over that
        # row cancels a constant offset, so the bias could never matter.
        self.key = nn.Linear(emb_dim, emb_dim, bias=False)
        self.value = nn.Linear(emb_dim, emb_dim)
        self.out = nn.Linear(emb_dim, emb_dim)
        # Filled in by forward(); None until the first call.
        self.scores = None

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (bs, seq_len, emb_dim)."""
        batch = x.shape[0]
        head_size = self.emb_dim // self.num_heads

        def split_heads(projection):
            # (bs, seq_len, emb_dim) -> (bs, num_heads, seq_len, head_size)
            heads = projection(x).reshape(batch, -1, self.num_heads, head_size)
            return heads.transpose(1, 2)

        q = split_heads(self.query)
        k = split_heads(self.key)
        v = split_heads(self.value)

        # Scaled dot-product scores, normalised over the key axis.
        scale = 1 / math.sqrt(head_size)
        weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        self.scores = weights.clone()

        # (bs, num_heads, seq_len, head_size) -> (bs, seq_len, emb_dim)
        merged = (weights @ v).transpose(1, 2)
        merged = merged.reshape(batch, -1, self.emb_dim)
        return self.out(merged)

class MLPBlock(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear.

    Both linear layers map ``emb_dim`` to ``emb_dim``, so the output has
    the same shape as the input.

    Args:
        emb_dim: size of the last dimension of the input.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.ff1 = nn.Linear(emb_dim, emb_dim)
        self.ff2 = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last dim ``emb_dim``)."""
        # The layers must be *called* on x; passing the module object itself
        # to gelu/dropout raises. F.dropout's keywords are `p` and
        # `training`. training=False keeps dropout switched off, as the
        # original call requested; pass self.training instead to enable it
        # in train mode.
        x = F.dropout(F.gelu(self.ff1(x)), p=0.1, training=False)
        x = F.dropout(self.ff2(x), p=0.1, training=False)
        return x

class EncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block with residual connections.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``, so
    the output has the same shape as the input.

    Args:
        in_ch: unused by this block; accepted so the constructor signature
            stays compatible with existing callers.
        emb_dim: embedding size shared by attention, MLP and LayerNorms.
        num_heads: number of attention heads.
    """

    def __init__(self, in_ch, emb_dim, num_heads):
        super().__init__()
        self.attn = MultiHeadAttn(emb_dim, num_heads)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.mlp = MLPBlock(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Apply attention and MLP sub-layers to ``x``, each with a skip."""
        # Each sub-layer sees the normalised input, but its output is added
        # back onto the un-normalised stream. Without these skip
        # connections the block discards its input after every sub-layer.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class Encoder(nn.Module):
    """Convolutional stem + positional embedding + stack of encoder blocks.

    Args:
        in_ch: number of input channels fed to the first Conv1d.
        emb_dim: embedding size used throughout the encoder.
        num_heads: number of attention heads per encoder block.
        num_encs: number of stacked encoder blocks.
        max_len: number of rows in the positional-embedding table, i.e. the
            longest sequence (after the stride-2 convolution) the encoder
            accepts. Defaults to 1500, the previously hard-coded value.
    """

    def __init__(self, in_ch, emb_dim, num_heads, num_encs, max_len=1500):
        super().__init__()
        self.in_ch = in_ch
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.num_encs = num_encs
        self.emb = nn.Sequential(*[
            nn.Conv1d(in_ch, emb_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(emb_dim, emb_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        ])
        # not implementing sin_pos_emb b/c
        # pretrained weights already provide pos_emb values.
        # register_buffer() returns None, so its result must NOT be assigned
        # back to self.pos_emb (that would overwrite the buffer with None);
        # the buffer is already reachable as self.pos_emb after this call.
        self.register_buffer('pos_emb', torch.zeros((max_len, emb_dim)))
        self.model = nn.Sequential(*[
            EncoderBlock(in_ch, emb_dim, num_heads)
            for _ in range(num_encs)
        ])

        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Encode ``x`` of shape (bs, in_ch, n_frames).

        Returns a tensor of shape (bs, seq_len, emb_dim), where seq_len is
        the length left after the stride-2 convolution.
        """
        # Conv1d produces (bs, emb_dim, seq_len); the positional table and
        # the attention blocks work on (bs, seq_len, emb_dim).
        x = self.emb(x).permute(0, 2, 1)
        # Use only as many positions as the sequence has; a sequence longer
        # than the table still fails with a broadcasting error.
        x = x + self.pos_emb[: x.shape[1]]
        x = self.model(x)
        x = self.norm(x)
        return x

    def load_pretrained_model(self, path):
        """Placeholder for loading pretrained weights; not implemented."""
        ...