In [1]:
import os
import resource
from gpt import GPT_warpper
import torch
# 挑选最适合的设备
device = "cuda" if torch.cuda.is_available() else "cpu"

# 模型下载
from modelscope import snapshot_download
model_dir = snapshot_download('mirror013/ChatTTS')

model_config = {
    'num_audio_tokens': 626,
    'num_text_tokens': 21178,
    'gpt_config': {
        'hidden_size': 768,
        'intermediate_size': 3072,
        'num_attention_heads': 12,
        'num_hidden_layers': 20,
        'use_cache': False,
        'max_position_embeddings': 4096,
        'spk_emb_dim': 192,
        'spk_KL': False,
        'num_audio_tokens': 626,
        'num_text_tokens': None,
        'num_vq': 4
    }
}
# 加载模型
model = GPT_warpper(**model_config).to(device).eval()
model.load_state_dict(torch.load(os.path.join(model_dir,"asset/GPT.pt"),map_location='cpu'))

# 加载分词器
tokenizer = torch.load(os.path.join(model_dir,"asset/tokenizer.pt"), map_location='cpu')
tokenizer.padding_side = 'left'

# 设置参数

# 设置参数
top_k = 50
top_p = 0.7
temperature = 0.7
max_length = 200
text = "你好呀旅行者"

print("all load done.")
def print_resource_usage():
    # 获取最大内存使用量（以字节为单位）
    max_memory_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # 打印内存使用情况
    print(f"Max Memory Usage: {max_memory_usage / 1024} MB")

    # 获取用户模式和系统模式下的CPU时间
    user_time, sys_time = resource.getrusage(resource.RUSAGE_SELF)[:2]
    # 打印CPU时间
    print(f"User Mode CPU Time: {user_time} seconds")
    print(f"System Mode CPU Time: {sys_time} seconds")

        # 如果CUDA设备可用，打印CUDA内存使用情况
    if torch.cuda.is_available():
        # 获取CUDA设备的总内存和空闲内存
        free_memory, total_memory = torch.cuda.mem_get_info()
        # 计算已使用的内存
        used_memory = total_memory - free_memory
        # 打印CUDA内存使用情况
        print(f"CUDA Memory Usage: {used_memory / 1024**2} MB used out of {total_memory / 1024**2} MB")

2024-07-08 17:09:46,307 - modelscope - INFO - PyTorch version 2.1.0 Found.
2024-07-08 17:09:46,308 - modelscope - INFO - Loading ast index from /Users/charslee/.cache/modelscope/ast_indexer
2024-07-08 17:09:46,336 - modelscope - INFO - Loading done! Current index file version is 1.13.3, with md5 8e4efa69aee288a831cd8dd27b421a93 and a total number of 972 components indexed


all load done.


## 准备input_ids等前置工作

In [4]:
import torch
from transformers.generation.logits_process import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
import os
import gc
import torch
import time


@torch.no_grad()
def generate_with_sampling(prompt, max_new_tokens, **kwargs):

    if not isinstance(prompt, list):
        prompt = [prompt]

    # 添加必要的标签
    prompt = [f'[Stts][empty_spk]{i}[Ptts]' for i in prompt]

    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, padding=True)

    # 初始化 logits processors
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature),
        TopKLogitsWarper(top_k=top_k),
        TopPLogitsWarper(top_p=top_p),
        RepetitionPenaltyLogitsProcessor(penalty=1.2)
    ])

    attention_mask = torch.ones(
        input_ids.shape, dtype=torch.long, device=input_ids.device)
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

    model_kwargs = {'use_cache': True,
                    'attention_mask': attention_mask,
                    'cache_position': cache_position,
                    'past_key_values': None}
    for _ in range(max_new_tokens):
        # 构建传入的生成参数
        model_input = model.prepare_inputs_for_generation(input_ids=input_ids,
                                                          **model_kwargs)

        # 前向传播
        outputs = model.forward(**model_input,
                                return_dict=True,
                                output_attentions=False,
                                output_hidden_states=False)

        # 应用 logits processors
        next_token_logits = outputs.logits[:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)

        # 使用多项分布采样下一个 token
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # 更新生成的序列
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # 检查是否生成了结束标记
        if next_tokens.item() == tokenizer.eos_token_id:
            break

        # 更新缓存等其他操作
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=False,
        )

        yield tokenizer.decode(next_tokens[0])
        if torch.backends.mps.is_available():
            gc.collect()
            torch.mps.empty_cache()


# 监控generate_with_sampling函数的资源使用
start_time = time.time()

for generated_text in generate_with_sampling(
        prompt=text, max_new_tokens=max_length):
    print(f"{generated_text}", end='')
end_time = time.time()
print("generate_with_sampling function:")
print(f"Time taken: {end_time - start_time} seconds")
print_resource_usage()

AttributeError: 