# Imports

In [1]:
from typing import Optional, Generator
import os
import math
import random
import itertools
from pathlib import Path

import numpy as np
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import torchinfo
import whisper
import IPython.display as ipd
from torch.utils.mobile_optimizer import optimize_for_mobile

from transformers import GPT2TokenizerFast
import model2
import whisper

# Model loading

In [2]:
whisper_path = Path("~/.cache/whisper/base.pt").expanduser()  # download the weights first using official whisper
with open(whisper_path, "rb") as f:
    checkpoint = torch.load(f)
ipd.display(checkpoint.keys())
ipd.display(checkpoint["dims"])

dict_keys(['dims', 'model_state_dict'])

{'n_mels': 80,
 'n_vocab': 51865,
 'n_audio_ctx': 1500,
 'n_audio_state': 512,
 'n_audio_head': 8,
 'n_audio_layer': 6,
 'n_text_ctx': 448,
 'n_text_state': 512,
 'n_text_head': 8,
 'n_text_layer': 6}

In [3]:
ori = whisper.load_model("base", device="cpu").eval()  # original model loading
model_dims = model2.ModelDimensions(**checkpoint["dims"])
modded = model2.Whisper(model_dims).eval()
modded.load_state_dict(checkpoint["model_state_dict"])
scripted = torch.jit.script(modded).eval()

# Simple encoding test

In [4]:
audio = whisper.load_audio("tests/jfk.flac")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)

In [5]:
ori_encoded = ori.encoder(mel)
modded_encoded = modded.encoder(mel)
scripted_encoded = scripted.encoder(mel)

assert torch.allclose(ori_encoded, modded_encoded)
assert torch.allclose(ori_encoded, scripted_encoded)

# Simple decoding test

In [6]:
# this is <|startoftranscript|><|en|><|transcribe|><|notimestamps|> from gpt2 tokenizer
tokens = torch.tensor([50258, 50259, 50359, 50363]).unsqueeze(0)
tokens

tensor([[50258, 50259, 50359, 50363]])

In [7]:
ori_decoded = ori.decoder(tokens, ori_encoded)
modded_decoded = modded.decoder(tokens, ori_encoded, {})
scripted_decoded = scripted.decoder(tokens, ori_encoded, {})

assert torch.allclose(ori_decoded, modded_decoded)
assert torch.allclose(ori_decoded, scripted_decoded)

# Greedy decoding

In [8]:
# the already built tokenizer, no need to add manually
# only works for multilingual for now
tokenizer = GPT2TokenizerFast.from_pretrained("whisper/assets/whisper_mult_gpt2")
tokenizer

PreTrainedTokenizerFast(name_or_path='whisper/assets/whisper_mult_gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'additional_special_tokens': ['<|startoftranscript|>', '<|en|>', '<|zh|>', '<|de|>', '<|es|>', '<|ru|>', '<|ko|>', '<|fr|>', '<|ja|>', '<|pt|>', '<|tr|>', '<|pl|>', '<|ca|>', '<|nl|>', '<|ar|>', '<|sv|>', '<|it|>', '<|id|>', '<|hi|>', '<|fi|>', '<|vi|>', '<|iw|>', '<|uk|>', '<|el|>', '<|ms|>', '<|cs|>', '<|ro|>', '<|da|>', '<|hu|>', '<|ta|>', '<|no|>', '<|th|>', '<|ur|>', '<|hr|>', '<|bg|>', '<|lt|>', '<|la|>', '<|mi|>', '<|ml|>', '<|cy|>', '<|sk|>', '<|te|>', '<|fa|>', '<|lv|>', '<|bn|>', '<|sr|>', '<|az|>', '<|sl|>', '<|kn|>', '<|et|>', '<|mk|>', '<|br|>', '<|eu|>', '<|is|>', '<|hy|>', '<|ne|>', '<|mn|>', '<|bs|>', '<|kk|>', '<|sq|>', '<|sw|>', '<|gl|>', '<|mr|>', '<|pa|>', '<|si|>', '<|km|>', '<|sn|>', '<|yo|>

In [9]:
# suppressed tokens, see SuppressBlank and SuppressTokens class
suppress_blanks = [220, 50257]
suppress_nonspeech = [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 
    93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 
    3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 
    14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 
    32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362]

In [10]:
tokens = tokenizer.encode("<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", return_tensors="pt")
tokens, tokens.size()

(tensor([[50258, 50259, 50359, 50363]]), torch.Size([1, 4]))

In [11]:
options = whisper.DecodingOptions(language="en", fp16=False)
ori_transcribed = whisper.decode(ori, mel, options)
ori_transcribed[0].text

'And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.'

In [12]:
modded_transcribed = modded.greedy_decode(tokens, mel, suppress_blanks, suppress_nonspeech)
tokenizer.batch_decode(modded_transcribed, skip_special_tokens=True)[0]

' And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.'

In [13]:
scripted_transcribed = scripted.greedy_decode(tokens, mel, suppress_blanks, suppress_nonspeech)
tokenizer.batch_decode(scripted_transcribed, skip_special_tokens=True)[0]

  scripted_transcribed = scripted.greedy_decode(tokens, mel, suppress_blanks, suppress_nonspeech)


' And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.'

# Scratchpad

In [14]:
class Dummy(nn.Module):
    
    def __init__(self, keygen: Generator) -> None:
        super().__init__()
        self.unique_num = next(keygen)
        self.unique_num = next(keygen)
        self.lin = nn.Linear(4, 4)
    
    def forward(self, x: Tensor, cache: dict[int, Tensor]):
        if self.unique_num not in cache:
            cache[self.unique_num] = self.lin(x)
        return cache[self.unique_num]
    
    @torch.jit.export
    def generate(self, x: Tensor):
        print(self.unique_num)
        cache: dict[int, Tensor] = {}
        a = self.forward(x, cache)
        b = self.forward(x*2, cache)
        return a-b

keygen = itertools.count()
dummy = Dummy(keygen)
sdummy = torch.jit.script(dummy)

In [15]:
dummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [16]:
sdummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [17]:
class Net(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.training = nn.Linear(5, 4)  # hmm

    def forward(self, x):
        return x


net = Net()
net.eval()

TypeError: cannot assign 'bool' as child module 'training' (torch.nn.Module or None expected)