# Imports

In [1]:
from typing import Optional, Generator
import math
import random
import itertools
from pathlib import Path

import numpy as np
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import torchinfo
import whisper
import IPython.display as ipd

from transformers import GPT2TokenizerFast
import model2
import whisper

# Model loading

In [2]:
tiny_path = Path("~/.cache/whisper/tiny.pt").expanduser()
with open(tiny_path, "rb") as f:
    checkpoint = torch.load(f)
ipd.display(checkpoint.keys())
ipd.display(checkpoint["dims"])

dict_keys(['dims', 'model_state_dict'])

{'n_mels': 80,
 'n_vocab': 51865,
 'n_audio_ctx': 1500,
 'n_audio_state': 384,
 'n_audio_head': 6,
 'n_audio_layer': 4,
 'n_text_ctx': 448,
 'n_text_state': 384,
 'n_text_head': 6,
 'n_text_layer': 4}

In [3]:
ori = whisper.load_model("tiny", device="cpu").eval()  # original model loading
model_dims = model2.ModelDimensions(**checkpoint["dims"])
modded = model2.Whisper(model_dims).eval()
modded.load_state_dict(checkpoint["model_state_dict"])
scripted = torch.jit.script(modded).eval()

# Simple encoding test

In [4]:
audio = whisper.load_audio("tests/jfk.flac")
audio = whisper.pad_or_trim(audio)

mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)

# # detect the spoken language
# _, probs = model.detect_language(mel)
# print(f"Detected language: {max(probs, key=probs.get)}")

# # decode the audio
# options = whisper.DecodingOptions()
# result = whisper.decode(model, mel, options)

# # print the recognized text
# print(result.text)

In [5]:
ori_encoded = ori.encoder(mel)
modded_encoded = modded.encoder(mel)
scripted_encoded = scripted.encoder(mel)

print(torch.allclose(ori_encoded, modded_encoded))
print(torch.allclose(ori_encoded, scripted_encoded))

True
True


# Simple decoding test

In [6]:
# this is <|startoftranscript|><|en|><|transcribe|><|notimestamps|> from gpt2 tokenizer
tokens = torch.tensor([50258, 50259, 50359, 50363]).unsqueeze(0)
tokens

tensor([[50258, 50259, 50359, 50363]])

In [7]:
ori_decoded = ori.decoder(tokens, ori_encoded)
modded_decoded = modded.decoder(tokens, ori_encoded, {})
scripted_decoded = scripted.decoder(tokens, ori_encoded, {})

print(torch.allclose(ori_decoded, modded_decoded))
print(torch.allclose(ori_decoded, scripted_decoded))

True
True


# Scratchpad

In [8]:
class Dummy(nn.Module):
    
    def __init__(self, keygen: Generator) -> None:
        super().__init__()
        self.unique_num = next(keygen)
        self.unique_num = next(keygen)
        self.lin = nn.Linear(4, 4)
    
    def forward(self, x: Tensor, cache: dict[int, Tensor]):
        if self.unique_num not in cache:
            cache[self.unique_num] = self.lin(x)
        return cache[self.unique_num]
    
    @torch.jit.export
    def generate(self, x: Tensor):
        print(self.unique_num)
        cache: dict[int, Tensor] = {}
        a = self.forward(x, cache)
        b = self.forward(x*2, cache)
        return a-b

keygen = itertools.count()
dummy = Dummy(keygen)
sdummy = torch.jit.script(dummy)

In [9]:
dummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [10]:
sdummy.generate(torch.randn(3, 4))

1


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [11]:
class Net(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.training = nn.Linear(5, 4)  # hmm

    def forward(self, x):
        return x


net = Net()
net.eval()

TypeError: cannot assign 'bool' as child module 'training' (torch.nn.Module or None expected)