In [None]:
import os

import torch
from transformers import GPT2LMHeadModel, GPT2Config

from tts.gpt2_model import get_model

In [None]:
def load_model(path, vocab_size=53376, device='cpu'):
    """Build the custom GPT through ``get_model`` and switch it to eval mode.

    Args:
        path: checkpoint path, forwarded to ``get_model``.
        vocab_size: vocabulary size forwarded to ``get_model``. The default
            matches the ``vocab_size`` of the HF ``GPT2Config`` used for comparison.
        device: device forwarded to ``get_model``.

    Returns:
        The model returned by ``get_model``, after ``.eval()`` has been called
        on it (so dropout-style layers behave deterministically).
    """
    # Echo which checkpoint is being loaded.
    print(path)
    model = get_model(
        vocab_size=vocab_size,
        device=device,
        compile=False,
        path=path,
    )
    model.eval()
    return model

In [None]:
# Checkpoint of the custom GPT under test. Set the CUSTOM_GPT_PATH environment
# variable to use another copy; the default is the original author's local path.
custom_gpt_path = os.environ.get(
    'CUSTOM_GPT_PATH',
    '/home/romit/Desktop/meraki/hf_hub/audiotoken/semantic_detokenizer/semantic_s/hubert_semantic_acoustic_gpt_en.pt',
)

In [None]:
# Read the checkpoint once and take both entries from it; the file was
# previously unpickled twice. torch.load unpickles arbitrary objects, so only
# point this at a trusted checkpoint.
custom_checkpoint = torch.load(custom_gpt_path, map_location='cpu')
custom_gpt_config = custom_checkpoint['config']
custom_gpt = custom_checkpoint['model']  # raw state dict, converted to HF layout later

# NOTE(review): get_model is handed the same path and presumably reads the
# file again — confirm whether it can accept the already-loaded checkpoint.
custom_m = load_model(custom_gpt_path)

In [None]:
# HF GPT-2 hyper-parameters mirroring the custom model, with every dropout disabled.
# NOTE(review): `use_bias` and `dropout` are not GPT2Config fields. GPT2Config
# keeps unknown kwargs as plain attributes, and HF GPT-2 layers appear to be
# built with biases regardless — hence the all-zero bias check in the next cell.
hf_config_kwargs = dict(
    vocab_size=53376,
    n_positions=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    use_bias=False,
    dropout=0,
    attn_pdrop=0,
    embd_pdrop=0,
    resid_pdrop=0,
    summary_first_dropout=0,
    activation_function='gelu',
)
config = GPT2Config(**hf_config_kwargs)

# Freshly initialised HF weights; only inspected by the bias check.
hf_gpt = GPT2LMHeadModel(config).state_dict()

In [None]:
# The custom model is assumed to be bias-free, so after the non-strict load the
# HF model keeps its initial biases. That is only equivalent to "no bias" if
# every freshly initialised bias entry is exactly zero. Count non-zero entries
# rather than summing: a zero sum does not imply all-zero (+1 and -1 cancel).
for k, v in hf_gpt.items():
    if '.bias' in k:
        assert torch.count_nonzero(v).item() == 0, f'Non-zero bias found in {k}'

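Convert the custom checkpoint to the HF GPT-2 layout:

1. Remove the unwanted `_orig_mod.` prefix, which `torch.compile` adds to parameter names.
2. Remove bias: the custom checkpoint has no bias tensors, so the HF model's biases stay at their all-zero initial values.
3. Transpose certain layers: HF GPT-2 stores the attention and MLP projection weights in `Conv1D` layout (`in_features × out_features`).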

In [None]:
# Strip the '_orig_mod.' prefix that torch.compile adds to parameter names.
# Keys without the prefix are kept as-is. Previously they were dropped, so a
# checkpoint saved from an uncompiled model produced an empty dict, and the
# non-strict load below then silently loaded nothing.
unwanted_prefix = '_orig_mod.'
clean_custom_gpt = {
    (k[len(unwanted_prefix):] if k.startswith(unwanted_prefix) else k): v
    for k, v in custom_gpt.items()
}

# HF GPT-2 implements these projections with Conv1D, whose weight is stored as
# (in_features, out_features). The custom weights presumably use nn.Linear's
# (out_features, in_features) layout, so they are transposed.
transposed = [
    'attn.c_attn.weight',
    'attn.c_proj.weight',
    'mlp.c_fc.weight',
    'mlp.c_proj.weight',
]

clean_custom_gpt = {
    k: (v.t() if any(k.endswith(w) for w in transposed) else v)
    for k, v in clean_custom_gpt.items()
}

In [None]:
model = GPT2LMHeadModel(config)

# strict=False because the HF model has bias entries the custom checkpoint lacks.
# Keep the result instead of discarding it, so a key-naming mismatch cannot
# silently leave real weights at their random initialisation.
load_result = model.load_state_dict(clean_custom_gpt, strict=False)
missing_weights = [k for k in load_result.missing_keys if not k.endswith('bias')]
assert not missing_weights, f'Weights not loaded from custom checkpoint: {missing_weights}'
load_result

In [None]:
# Module name -> latest forward output, one dict per model.
store_hf = {}
store_custom = {}

def hook(module, input, output, name, store):
    """Forward-hook body: record ``output`` in ``store`` under the module's ``name``."""
    store[name] = output

def register_hook(m, store):
    """Attach a forward hook to every module of ``m``, including ``m`` itself.

    After each forward pass, ``store`` maps every ``named_modules()`` name that
    was called to that module's most recent output.

    Returns:
        A list of the hook handles. Call ``handle.remove()`` on each one to
        detach the hooks. Hooks stay on the module objects, so re-registering
        on the same model without removing them stacks duplicates.
    """
    handles = []
    for name, layer in m.named_modules():
        # `name=name` binds the current name now; a plain closure would only
        # see the last value of the loop variable.
        handles.append(
            layer.register_forward_hook(
                lambda layer, input, output, name=name: hook(layer, input, output, name, store)
            )
        )
    return handles

# Capture per-module outputs of both models so they can be compared by name.
for net, sink in ((model, store_hf), (custom_m, store_custom)):
    register_hook(net, sink)

In [None]:
# Random token ids fed to both models. A seeded, local generator makes the
# comparison reproducible without disturbing the global RNG state.
# NOTE(review): ids are drawn from [0, 50000) although vocab_size is 53376 —
# presumably intentional; confirm.
input_gen = torch.Generator().manual_seed(0)
inputs = torch.randint(0, 50000, (1, 100), generator=input_gen)

# Forward passes only fill the hook stores; gradients are not needed.
with torch.no_grad():
    pretrained_out = model(inputs)
    custom_out = custom_m(inputs)

In [None]:
# Layer-by-layer max absolute difference between the two models' activations.
# HF modules with no same-named custom counterpart are reported and skipped.
# The loop stops when it reaches 'lm_head'. The loop names `k` and `v` are kept
# because they remain bound after the loop ends.
for k, v in store_hf.items():
    if k not in store_custom:
        print(f'{k} not found in custom')
        continue
    if k == 'lm_head':
        break

    # Some modules return tuples; only the first element is compared.
    hf_act = v[0] if type(v) is tuple else v
    custom_raw = store_custom[k]
    custom_act = custom_raw[0] if type(custom_raw) is tuple else custom_raw

    max_abs_diff = (hf_act - custom_act).abs().max()
    print(f'{k}\t\t\t\t{max_abs_diff}')

In [None]:
# Compare the final logits by key, instead of relying on loop variables
# (`k`, `v`) leaked from the previous cell.
# NOTE(review): HF's lm_head output is sliced to the last position. This
# assumes the custom model's lm_head only produces last-position logits —
# confirm the shapes line up rather than broadcast.
(store_hf['lm_head'][:, -1, :] - store_custom['lm_head']).abs().max()

Testing generation (text → semantic tokens → audio)

In [None]:
import numpy as np
from tts.infer import AudioSemantic

from pathlib import Path
from audiotoken import AudioToken, Tokenizers

from IPython.display import display, Audio

In [None]:
# Generation helper exposing text_to_semantic / semantic_to_audio ('125m' size),
# plus the semantic tokenizer used below to encode a reference recording.
semlib = AudioSemantic(size='125m')
semantic_tokenizer = AudioToken(Tokenizers.semantic_s)

In [None]:
# Convert a sample sentence to semantic tokens; `toks` feeds semantic_to_audio below.
toks = semlib.text_to_semantic('my name is romit jain')

In [None]:
# Encode a reference recording and flatten the first item's tokens to a 1-D numpy array.
# NOTE(review): `sem_toks` is not used by the cells that follow; the path is a local absolute path.
reference_wav = Path('/home/romit/Downloads/audio/sent_3.wav')
encoded_reference = semantic_tokenizer.encode(reference_wav)
sem_toks = encoded_reference[0][0].reshape(-1).numpy()

In [None]:
# Synthesise audio from the semantic tokens produced by text_to_semantic.
wav = semlib.semantic_to_audio(toks)

In [None]:
# Play the generated audio inline.
# NOTE(review): assumes semantic_to_audio outputs 24 kHz audio — confirm.
display(Audio(wav, rate=24000))