In [1]:
from utils.load_data.load_asr_data import data_module
import utils.tools.config as toolcfg

from models.encoder.encoder import speechEncoder,make_pad_mask
from models.encoder.cmvn import GlobalCMVN, load_cmvn
import yaml,torch
from models.adapter import CNNSubsampling
device = "cuda:0"
# encoder and adapter
configs = yaml.safe_load(open('Freeze-Omni/checkpoints/audiollm/train.yaml', 'r'))
configs['cmvn_file'] =  "Freeze-Omni/checkpoints/audiollm/global_cmvn"
# read cmvn
mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
# init cmvn layer
global_cmvn = GlobalCMVN(
    torch.from_numpy(mean).float(),
    torch.from_numpy(istd).float())

encoder = speechEncoder(configs["input_dim"], global_cmvn=global_cmvn, 
                        **configs['encoder_conf'])
model_conf = configs["model_conf"]
adapter = CNNSubsampling(
    model_conf["enc_out_dim"], model_conf["llm_embed_dim"], 
    model_conf["kernel_size"], model_conf["activation_func"], 
    model_conf["norm"])

# datasets
conf = toolcfg.yaml2namespace("config/stage_1b.yaml")
dm_here = data_module(conf)
dm_here.setup("fit")
trn_loader = dm_here.train_dataloader()


  from .autonotebook import tqdm as notebook_tqdm


the number of speech encoder params: 341.3681640625M


In [2]:
TOKENIZERS_PARALLELISM = True
ele = next(iter(trn_loader))
fbank, fbank_len, target, target_len,texts = ele
print(fbank.shape)
print(fbank_len)
print(target.shape)
print(target_len)
print(texts)

torch.Size([2, 1490, 80])
tensor([1490, 1393])
torch.Size([2, 43])
tensor([43, 39])
['so at eight years old she began she learnt a year and could not bear it and missus morland who did not insist on her daughters being accomplished in spite of incapacity or distaste allowed her to leave off', 'and if she gathered flowers at all it was chiefly for the pleasure of mischief at least so it was conjectured from her always preferring those which she was forbidden to take such were her propensities']


In [3]:
encoder_out ,encoder_mask = encoder(fbank,fbank_len)
input_lengths = torch.tensor([torch.sum(ele).item() for ele in encoder_mask],dtype=torch.long).to(encoder_out.device)
print(input_lengths)
print(encoder_out.shape)
# adapter
inputs_embeds, encoder_mask, cnn_cache = adapter(
    encoder_out, encoder_mask, 
    cache=None, return_cache=True) # 1, T, D
print(inputs_embeds.shape)
print(encoder_mask.shape)

tensor([371, 347])
torch.Size([2, 371, 1024])
torch.Size([2, 186, 3584])
torch.Size([2, 1, 186])


In [4]:
from transformers import AutoTokenizer,AutoModelForCausalLM,Qwen2ForCausalLM, Qwen2TokenizerFast,Qwen2Tokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("Qwen2-7B-Instruct", 
                                                    trust_remote_code=True)
llm_path = "Qwen2-7B-Instruct"
LLM = AutoModelForCausalLM.from_pretrained(
    llm_path, 
    torch_dtype=torch.float32,
    trust_remote_code=True).to(device)
print(LLM)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm):

In [5]:
print(inputs_embeds.dtype)
print(encoder_mask.dtype)
for name, param in LLM.named_parameters():
    print(f"{name}: {param.dtype}")
    break

torch.float32
torch.bool
model.embed_tokens.weight: torch.float32


In [6]:
if not hasattr(tokenizer, "eod_id"):
    tokenizer.eod_id = tokenizer.eos_token_id
if not hasattr(LLM, "transformer"):
    LLM.transformer = LLM.model
    LLM.transformer.h = LLM.transformer.layers
if not hasattr(LLM.transformer, "wte"):
    LLM.transformer.wte = \
    LLM.transformer.embed_tokens

In [7]:
target_len_mask = ~make_pad_mask(target_len, target.size(1)).unsqueeze(1)
target_len_mask.squeeze(1).shape

torch.Size([2, 43])

In [21]:
# # inputs_embeds，encoder_mask，target，target_len_mask
# print(inputs_embeds.shape)
# # 1、获取输出的3584embedding
# outputs_embed = LLM.model.embed_tokens(target.to(device))
# print(outputs_embed.shape)
# # 2、将两个3584拼接传入inputs_embeds字段
# all_inputs_embeds = torch.cat([inputs_embeds.to(device), outputs_embed.to(device)], dim=1)
# print(all_inputs_embeds.shape)
# # 3、拼接attention mask
# print(target_len_mask.shape)
# print(encoder_mask.shape)
# all_attmask = torch.cat([target_len_mask.to(device), encoder_mask.to(device)], dim=-1).squeeze(1)
# print(all_attmask.shape)
# # 4、构建labels
# # print(target)
# labels_1 = torch.full((inputs_embeds.size(0), inputs_embeds.size(1)), -100, dtype=torch.long)
# print(labels_1.shape)
# labels = torch.cat([labels_1.to(device), target.to(device)], dim=-1)
# print(labels)
# labels[labels == 151643] = -100
# print(labels)
# # 5、组成{inputs_embeds，attention map，labels}调用forward
# final_inuts = {
#     "inputs_embeds": all_inputs_embeds,
#     "attention_mask": all_attmask,
#     "labels": labels
# }
special_tokennizer = torch.nn.Embedding(20,3584)
task_ids = {
            "sot": 0,
            "transcribe": 1,
            "translate": 2,
            "zh": 3,
            "en": 4,
            "audio": 5,
            "/audio": 6,
            "hyps": 7,
            "/hyps": 8,
        }
pretend = torch.tensor([
    task_ids["translate"],
    task_ids["zh"],
    task_ids["audio"],
    task_ids["/audio"], 
    task_ids["sot"]]).long()
label_pretend = pretend.unsqueeze(0).expand(2, -1)
print(label_pretend.shape)
special_token = special_tokennizer(label_pretend)
print(special_token.shape)
# inputs_embeds，encoder_mask，target，target_len_mask
print(inputs_embeds.shape)
# 1、获取输出的3584embedding
outputs_embed = LLM.model.embed_tokens(target.to(device))
print(outputs_embed.shape)
# 2、将两个3584拼接传入inputs_embeds字段
all_inputs_embeds = torch.cat([special_token[:,:-2,:].to(device),inputs_embeds.to(device),special_token[:,-2:,:].to(device), outputs_embed.to(device)], dim=1)
print(all_inputs_embeds.shape)
# 3、拼接attention mask
print(target_len_mask.shape)
print(encoder_mask.shape)
pretend_att_mask = torch.ones(label_pretend.shape, dtype=torch.bool).unsqueeze(1)
print(f"===={pretend_att_mask.shape}")
all_attmask = torch.cat([pretend_att_mask[:,:,:-2].to(device),encoder_mask.to(device), pretend_att_mask[:,:,-2:].to(device), target_len_mask.to(device)], dim=-1).squeeze(1)
print(all_attmask.shape)
# 4、构建labels
# print(target)
labels_1 = torch.full((inputs_embeds.size(0), inputs_embeds.size(1)), -100, dtype=torch.long)
print(labels_1.shape)
labels = torch.cat([label_pretend[:,:-2].to(device),labels_1.to(device), label_pretend[:,-2:].to(device), target.to(device)], dim=-1)
print(labels)
labels[labels == 151643] = -100
print(labels)
# 5、组成{inputs_embeds，attention map，labels}调用forward
final_inuts = {
    "inputs_embeds": all_inputs_embeds,
    # "attention_mask": all_attmask,
    "labels": labels
}



torch.Size([2, 5])
torch.Size([2, 5, 3584])
torch.Size([2, 186, 3584])
torch.Size([2, 43, 3584])
torch.Size([2, 234, 3584])
torch.Size([2, 1, 43])
torch.Size([2, 1, 186])
====torch.Size([2, 1, 5])
torch.Size([2, 234])
torch.Size([2, 186])
tensor([[     2,      3,      5,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,

In [22]:
forward_outputs = LLM(**final_inuts)
print(forward_outputs.keys())
print(forward_outputs)

odict_keys(['loss', 'logits', 'past_key_values'])
CausalLMOutputWithPast(loss=tensor(5.5531, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[ 3.6228,  6.2556,  5.0378,  ..., -1.5343, -1.5341, -1.5340],
         [ 2.9215,  3.7384,  7.8226,  ..., -1.2894, -1.2892, -1.2893],
         [ 2.3163,  3.9720,  6.2133,  ..., -2.1069, -2.1067, -2.1068],
         ...,
         [ 2.2140,  4.3985, -0.1241,  ..., -7.0003, -6.9993, -6.9992],
         [ 0.8529,  4.9888, -0.7693,  ..., -6.9249, -6.9247, -6.9246],
         [ 2.3894,  6.8588,  2.2718,  ..., -5.2253, -5.2246, -5.2245]],

        [[ 3.6228,  6.2556,  5.0378,  ..., -1.5343, -1.5341, -1.5340],
         [ 2.9215,  3.7384,  7.8226,  ..., -1.2894, -1.2892, -1.2893],
         [ 2.3163,  3.9720,  6.2133,  ..., -2.1069, -2.1067, -2.1068],
         ...,
         [ 3.1829,  4.2269,  7.4377,  ..., -3.7288, -3.7290, -3.7287],
         [ 3.3867,  4.4980,  8.4635,  ..., -4.0601, -4.0602, -4.0599],
         [ 3.4532,  5.1256,  9.1053,  ..., 

In [10]:
forward_outputs.logits.shape

torch.Size([1, 219, 152064])

In [None]:
from utils.load_data.load_asr_data import post_decode
post_decode(forward_outputs.logits)

[[tensor(279, device='cuda:0'),
  tensor(1773, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(908, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(198, device='cuda:0'),
  tensor(151643, device='cuda:0'),
  tensor(35147, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(3958, device='cuda:0'),
  tensor(8631, device='cuda:0'),
  tensor(549, device='cuda:0'),
  tensor(409, device='cuda:0'),
  tensor(510, device='cuda:0'),
  tensor(5397, device='cuda:0'),
  tensor(15249, device='cuda:0'),
  tensor(549, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(15249, device='cuda:0'),
  tensor(11, device='cuda:0'),
  tensor(409, device='cuda:0'),
  tensor(5397, device='cuda:0'),
  tensor(15249, device='cuda:0'),
  tensor(3837, device='cuda:0'),
  tensor(15249, device='cuda:0'),
  tensor(341, device='cuda:0'),
  tensor(908, device='cuda:0'),
  tensor(5397, device='cuda:0'),
  tensor(1842, 

In [None]:
tokenizer.decode(target[0])
tokenizer.decode(target[1])

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [12]:
target_ip = {
    "inputs_embeds": target.to(device),
    "attention_mask": target_len_mask.squeeze(1).to(device),
}
LLM.model(target.to(device)).last_hidden_state.shape

torch.Size([3, 41, 3584])

In [13]:
encoder_mask.squeeze(1)
# inputs_embeds.shape

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [14]:
result = torch.cat((target, target), dim=1)
print(result.shape)
inputs = {
    "inputs_embeds": inputs_embeds.to(device),
    "attention_mask": encoder_mask.squeeze(1).to(device),
}

outputs = LLM.model(**inputs)
outputs.keys()
print(outputs.last_hidden_state.shape)


torch.Size([3, 82])
torch.Size([3, 195, 3584])


In [15]:
lab = torch.full((3, 185), -100, dtype=torch.long, device=device)
lab.dtype

torch.int64

In [None]:
feature_inputs_forgen = LLM.prepare_inputs_for_generation(input_ids=inputs_embeds.to(device), attention_mask=encoder_mask.squeeze(1).to(device), use_cache=True)
for ele in feature_inputs_forgen.keys():
    try:
        print(f"{ele}: {feature_inputs_forgen[ele].shape}")
    except:
        pass
print(feature_inputs_forgen["attention_mask"])

input_ids: torch.Size([3, 195, 3584])
position_ids: torch.Size([3, 195])
attention_mask: torch.Size([3, 195])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True

In [17]:
feature_inputs_forgen["input_ids"].shape

torch.Size([3, 195, 3584])

In [18]:
output_code_forgen = LLM.prepare_inputs_for_generation((target.to(device)), attention_mask=target_len_mask.squeeze(1).to(device), use_cache=True)
output_code_forgen["input_ids"].shape

torch.Size([3, 41])

In [19]:
output_code_forgen_3584 = LLM.prepare_inputs_for_generation(LLM.model.embed_tokens(target.to(device)), attention_mask=target_len_mask.squeeze(1).to(device), use_cache=True)
output_code_forgen_3584

{'input_ids': tensor([[[ 0.0110, -0.0245,  0.0110,  ..., -0.0123,  0.0044, -0.0124],
          [ 0.0232,  0.0109, -0.0108,  ...,  0.0051, -0.0014,  0.0098],
          [ 0.0270,  0.0070,  0.0115,  ..., -0.0291,  0.0125,  0.0132],
          ...,
          [-0.0042, -0.0015, -0.0018,  ..., -0.0092,  0.0029,  0.0011],
          [ 0.0101,  0.0239, -0.0074,  ...,  0.0226, -0.0024,  0.0021],
          [ 0.0054,  0.0092,  0.0003,  ...,  0.0153, -0.0236, -0.0048]],
 
         [[ 0.0110, -0.0245,  0.0110,  ..., -0.0123,  0.0044, -0.0124],
          [ 0.0232,  0.0109, -0.0108,  ...,  0.0051, -0.0014,  0.0098],
          [ 0.0270,  0.0070,  0.0115,  ..., -0.0291,  0.0125,  0.0132],
          ...,
          [-0.0128,  0.0258,  0.0070,  ...,  0.0198, -0.0082,  0.0056],
          [-0.0128,  0.0258,  0.0070,  ...,  0.0198, -0.0082,  0.0056],
          [-0.0128,  0.0258,  0.0070,  ...,  0.0198, -0.0082,  0.0056]],
 
         [[-0.0322, -0.0038,  0.0106,  ...,  0.0049, -0.0030, -0.0168],
          [ 0.0

In [20]:
torch.equal(LLM.model.embed_tokens(target.to(device)),output_code_forgen_3584["input_ids"])

True

In [21]:


LLM(**output_code_forgen)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


CausalLMOutputWithPast(loss=None, logits=tensor([[[  2.6246,   4.5132,   1.9821,  ...,  -5.3154,  -5.3146,  -5.3146],
         [  4.2087,   4.8240,  -0.0305,  ...,  -8.6319,  -8.6310,  -8.6312],
         [  2.7274,   4.1058,  -1.1517,  ...,  -7.8094,  -7.8088,  -7.8088],
         ...,
         [ -0.7156,   0.2198,  -2.4423,  ...,  -6.5023,  -6.5016,  -6.5016],
         [  4.3961,   2.1485,  -1.0599,  ...,  -6.3731,  -6.3732,  -6.3732],
         [  4.1544,   3.9177,  -1.0849,  ...,  -7.1408,  -7.1404,  -7.1403]],

        [[  2.6246,   4.5132,   1.9821,  ...,  -5.3154,  -5.3146,  -5.3146],
         [  4.2087,   4.8240,  -0.0305,  ...,  -8.6319,  -8.6310,  -8.6312],
         [  2.7274,   4.1058,  -1.1517,  ...,  -7.8094,  -7.8088,  -7.8088],
         ...,
         [  2.1544,   3.9816,   0.1181,  ...,  -7.3458,  -7.3455,  -7.3455],
         [  2.1544,   3.9816,   0.1181,  ...,  -7.3458,  -7.3455,  -7.3455],
         [  2.1544,   3.9816,   0.1181,  ...,  -7.3458,  -7.3455,  -7.3455]],

   

In [22]:
LLM(inputs_embeds = feature_inputs_forgen["input_ids"])

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 5.9047,  5.6057,  4.9274,  ..., -2.4269, -2.4268, -2.4269],
         [ 4.4216,  6.7965,  5.8447,  ..., -2.4494, -2.4494, -2.4494],
         [ 5.5388,  8.1749,  6.3829,  ..., -2.5767, -2.5763, -2.5766],
         ...,
         [ 1.2829,  2.4710,  3.0594,  ..., -2.2449, -2.2448, -2.2448],
         [ 1.1615,  2.1759,  3.0698,  ..., -2.3072, -2.3071, -2.3071],
         [ 1.0991,  1.8960,  3.1573,  ..., -2.3648, -2.3647, -2.3647]],

        [[ 5.6605,  3.9358,  4.0527,  ..., -1.7673, -1.7671, -1.7670],
         [ 3.6854,  5.2759,  6.1540,  ..., -1.3357, -1.3353, -1.3354],
         [ 3.0376,  6.1811,  5.4352,  ..., -1.8181, -1.8177, -1.8180],
         ...,
         [ 0.9617,  1.3028,  3.3937,  ..., -2.5719, -2.5716, -2.5717],
         [ 1.0041,  1.3235,  3.3472,  ..., -2.5375, -2.5372, -2.5373],
         [ 1.0103,  1.3927,  3.3102,  ..., -2.5362, -2.5360, -2.5360]],

        [[ 4.9907,  5.5886,  3.4973,  ..., -2.5919, -2.5919, -2.5921],
    