# Imports

In [2]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.7.2-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.2


In [1]:
from typing import Optional, Generator
import os
import math
import random
import itertools
from pathlib import Path

import numpy as np
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import torchinfo
import whisper
import IPython.display as ipd
from torch.utils.mobile_optimizer import optimize_for_mobile

from transformers import GPT2TokenizerFast
import model2
import whisper

2023-03-18 01:55:02.706034: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Model loading

In [12]:
whisper_path = Path("./checkpoint/tiny.pt").expanduser()  # download the weights first using official whisper
with open(whisper_path, "rb") as f:
    checkpoint = torch.load(f)
ipd.display(checkpoint.keys())
ipd.display(checkpoint["dims"])

dict_keys(['dims', 'model_state_dict'])

{'n_mels': 80,
 'n_vocab': 51865,
 'n_audio_ctx': 1500,
 'n_audio_state': 384,
 'n_audio_head': 6,
 'n_audio_layer': 4,
 'n_text_ctx': 448,
 'n_text_state': 384,
 'n_text_head': 6,
 'n_text_layer': 4}

In [13]:
ori = whisper.load_model("tiny", device="cpu").eval()  # original model loading
model_dims = model2.ModelDimensions(**checkpoint["dims"])
modded = model2.Whisper(model_dims).eval()
modded.load_state_dict(checkpoint["model_state_dict"])
scripted = torch.jit.script(modded).eval()

In [26]:
ori

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0): ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (1): ResidualAttentionBlock(
        (attn): Mult

# Simple encoding test

In [2]:
audiolist = [
    "tests/jfk.flac",
    "tests/jfk_noise_front.wav",
    "tests/jfk_noise_middle.wav",
    "tests/jfk_noise_back.wav",
    "tests/noise_only.wav",
    "tests/debussy.wav",
]
audio = whisper.load_audio(audiolist[2])
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)

In [15]:
mel

tensor([[[-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387],
         [-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387],
         [-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387],
         ...,
         [-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387],
         [-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387],
         [-0.5387, -0.5387, -0.5387,  ..., -0.5387, -0.5387, -0.5387]]])

In [18]:
ori_encoded = ori.encoder(mel)
modded_encoded = modded.encoder(mel)
scripted_encoded = scripted.encoder(mel)

assert torch.allclose(ori_encoded, modded_encoded)
assert torch.allclose(ori_encoded, scripted_encoded)

# Simple decoding test

In [20]:
# this is <|startoftranscript|><|en|><|transcribe|><|notimestamps|> from gpt2 tokenizer
tokens = torch.tensor([50258, 50259, 50359, 50363]).unsqueeze(0)
tokens

tensor([[50258, 50259, 50359, 50363]])

In [21]:
ori_decoded = ori.decoder(tokens, ori_encoded)
modded_decoded = modded.decoder(tokens, ori_encoded, {})
scripted_decoded = scripted.decoder(tokens, ori_encoded, {})

assert torch.allclose(ori_decoded, modded_decoded)
assert torch.allclose(ori_decoded, scripted_decoded)

# Greedy decoding

In [3]:
# the already built tokenizer, no need to add manually
# only works for multilingual for now
tokenizer = GPT2TokenizerFast.from_pretrained("whisper/assets/whisper_mult_gpt2")
tokenizer

GPT2TokenizerFast(name_or_path='whisper/assets/whisper_mult_gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'additional_special_tokens': ['<|startoftranscript|>', '<|en|>', '<|zh|>', '<|de|>', '<|es|>', '<|ru|>', '<|ko|>', '<|fr|>', '<|ja|>', '<|pt|>', '<|tr|>', '<|pl|>', '<|ca|>', '<|nl|>', '<|ar|>', '<|sv|>', '<|it|>', '<|id|>', '<|hi|>', '<|fi|>', '<|vi|>', '<|iw|>', '<|uk|>', '<|el|>', '<|ms|>', '<|cs|>', '<|ro|>', '<|da|>', '<|hu|>', '<|ta|>', '<|no|>', '<|th|>', '<|ur|>', '<|hr|>', '<|bg|>', '<|lt|>', '<|la|>', '<|mi|>', '<|ml|>', '<|cy|>', '<|sk|>', '<|te|>', '<|fa|>', '<|lv|>', '<|bn|>', '<|sr|>', '<|az|>', '<|sl|>', '<|kn|>', '<|et|>', '<|mk|>', '<|br|>', '<|eu|>', '<|is|>', '<|hy|>', '<|ne|>', '<|mn|>', '<|bs|>', '<|kk|>', '<|sq|>', '<|sw|>', '<|gl|>', '<|mr|>', '<|pa|>', '<|si|>', '<|km|>', '<|sn|>', '<|yo|>', 

In [4]:
# suppressed tokens, see SuppressBlank and SuppressTokens class
suppress_blanks = [220, 50257]
suppress_nonspeech = [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 
    93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 
    3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 
    14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 
    32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362]

In [5]:
tokens = tokenizer.encode("<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", return_tensors="pt")
tokens, tokens.size()

(tensor([[50258, 50259, 50359, 50363]]), torch.Size([1, 4]))

In [6]:
with torch.no_grad():
    options = whisper.DecodingOptions(language="en", fp16=False, without_timestamps=True)
    ori_transcribed = whisper.decode(ori, mel, options)
    print(ori_transcribed[0].text)

NameError: name 'ori' is not defined

In [28]:
with torch.no_grad():
    modded_transcribed = modded.greedy_decode(tokens, mel, suppress_blanks, suppress_nonspeech)
    print(tokenizer.batch_decode(modded_transcribed, skip_special_tokens=True)[0])

 And so my fellow Americans......Ask not! What your country can do for you......Ask what you can do for your country.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load('./checkpoint/tiny_asr.pth')
model.to(device)

In [19]:
with torch.no_grad():
    scripted_transcribed = scripted.greedy_decode(tokens.to(device), mel.to(device), suppress_blanks, suppress_nonspeech)
    print(tokenizer.batch_decode(scripted_transcribed, skip_special_tokens=True)[0])

 And so my fellow Americans......Ask not! What your country can do for you......Ask what you can do for your country.


In [30]:
scripted.save('./checkpoint/tiny_asr.pth')

In [31]:
model = torch.jit.load('./checkpoint/tiny_asr.pth')

# Scratchpad

In [14]:
class Dummy(nn.Module):
    
    def __init__(self, keygen: Generator) -> None:
        super().__init__()
        self.unique_num = next(keygen)
        self.unique_num = next(keygen)
        self.lin = nn.Linear(4, 4)
    
    def forward(self, x: Tensor, cache: dict[int, Tensor]):
        if self.unique_num not in cache:
            cache[self.unique_num] = self.lin(x)
        return cache[self.unique_num]
    
    @torch.jit.export
    def generate(self, x: Tensor):
        print(self.unique_num)
        cache: dict[int, Tensor] = {}
        a = self.forward(x, cache)
        b = self.forward(x*2, cache)
        return a-b

keygen = itertools.count()
dummy = Dummy(keygen)
sdummy = torch.jit.script(dummy)

In [15]:
dummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [16]:
sdummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [17]:
class Net(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.training = nn.Linear(5, 4)  # hmm

    def forward(self, x):
        return x


net = Net()
net.eval()

TypeError: cannot assign 'bool' as child module 'training' (torch.nn.Module or None expected)