In [1]:
# 方便查看源代码
from transformers import (
    EncoderDecoderModel,
    GPT2LMHeadModel,
    GPT2Model,
)
GPT2Model.forward
GPT2LMHeadModel.forward
EncoderDecoderModel.forward
EncoderDecoderModel.generate
GPT2LMHeadModel.prepare_inputs_for_generation
        
import resource
import time

def print_resource_usage():
    # 获取最大内存使用量（以字节为单位）
    max_memory_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # 打印内存使用情况
    print(f"Max Memory Usage: {max_memory_usage / 1024} MB")

    # 获取用户模式和系统模式下的CPU时间
    user_time, sys_time = resource.getrusage(resource.RUSAGE_SELF)[:2]
    # 打印CPU时间
    print(f"User Mode CPU Time: {user_time} seconds")
    print(f"System Mode CPU Time: {sys_time} seconds")

        # 如果CUDA设备可用，打印CUDA内存使用情况
    if torch.cuda.is_available():
        # 获取CUDA设备的总内存和空闲内存
        free_memory, total_memory = torch.cuda.mem_get_info()
        # 计算已使用的内存
        used_memory = total_memory - free_memory
        # 打印CUDA内存使用情况
        print(f"CUDA Memory Usage: {used_memory / 1024**2} MB used out of {total_memory / 1024**2} MB")

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import os
import gc
import torch

# 挑选最适合的设备
device = "cuda" if torch.cuda.is_available() else "cpu"

# 示例使用
model_name = "openai-community/gpt2"  # 使用较小的模型作为示例
model = AutoModelForCausalLM.from_pretrained(
    model_name,).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 显示具体的类型
print(type(model))

# 设置参数
top_k = 50
top_p = 0.7
temperature = 0.7
max_length = 200
text = 'Once upon a time,'

<class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>


In [3]:
def _trim_past_key_values(past, max_length):
    # 裁剪 past_key_values 以保留最近的 token 信息
    return tuple(
        tuple(p[:, :, :max_length, :].contiguous() for p in layer)
        for layer in past
    )

@torch.no_grad()
def generate_with_sampling(model, tokenizer, prompt, max_new_tokens, **kwargs):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # 初始化 logits processors
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature),
        TopKLogitsWarper(top_k=top_k),
        TopPLogitsWarper(top_p=top_p),
        RepetitionPenaltyLogitsProcessor(penalty=1.2)
    ])

    attention_mask = torch.ones(
        input_ids.shape, dtype=torch.long, device=input_ids.device)
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

    model_kwargs = {'use_cache': True,
                    'attention_mask': attention_mask,
                    'cache_position': cache_position,
                    'past_key_values': None}
    for _ in range(max_new_tokens):
        # 构建传入的生成参数
        model_input = model.prepare_inputs_for_generation(input_ids=input_ids,
                                                          **model_kwargs)

        # 前向传播
        outputs = model.forward(**model_input,
                                return_dict=True,
                                output_attentions=False,
                                output_hidden_states=False)

        # 应用 logits processors
        next_token_logits = outputs.logits[:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)

        # 使用多项分布采样下一个 token
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # 更新生成的序列
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # 检查是否生成了结束标记
        if next_tokens.item() == tokenizer.eos_token_id:
            break

        # 更新缓存等其他操作
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=False,
        )

        past = model_kwargs['past_key_values']

        # print(f"Input ids shape: {input_ids.shape}")
        # print(f"Attention mask shape: {model_kwargs['attention_mask'].shape}")
        # print(f"Past key values shape: {past[0][0].shape if past else 'None'}")

        yield tokenizer.decode(next_tokens[0])
        if torch.backends.mps.is_available():
            gc.collect()
            torch.mps.empty_cache()

# 监控generate_with_sampling函数的资源使用
start_time = time.time()

for generated_text in generate_with_sampling(
    model, tokenizer, text, max_length):
    print(f"{generated_text}",end='')
end_time = time.time()
print("generate_with_sampling function:")
print(f"Time taken: {end_time - start_time} seconds")
print_resource_usage()

 the people of the world were able to realize that they had been wronged by the gods. The gods had given them power over their own souls, and they were not alone in this.

There are many reasons why we need to consider the possibility that our ancestors might have been wronged by the gods. One is that we do not know how much of our past history was based on these beliefs. Many of us believe that the gods did not give us any knowledge of how to survive. Others believe that we were born into the same universe as the gods, and that the gods gave us powers to control our bodies. Some of us believe that the gods created us to fight for the good of mankind, but others believe that they made us to live in the darkness of the night. We also believe that the gods gave us power to create new worlds, but they did not make us immortal.

One of the most important things we can learn from the ancient Greeks is that the godsgenerate_with_sampling function:
Time taken: 13.676683902740479 seconds
Max M

In [5]:
@torch.no_grad()
def generate_with_sampling(model, tokenizer, prompt, max_new_tokens, **kwargs):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # 获取词嵌入层
    embeddings = model.get_input_embeddings()
    
    # 初始化 logits processors
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature),
        TopKLogitsWarper(top_k=top_k),
        TopPLogitsWarper(top_p=top_p),
        RepetitionPenaltyLogitsProcessor(penalty=1.2)
    ])

    attention_mask = torch.ones(
        input_ids.shape, dtype=torch.long, device=input_ids.device)
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

    model_kwargs = {'use_cache': True,
                    'attention_mask': attention_mask,
                    'cache_position': cache_position,
                    'past_key_values': None}
    for _ in range(max_new_tokens):
        # 构建传入的生成参数
        model_input = model.prepare_inputs_for_generation(input_ids=input_ids,
                                                          **model_kwargs)

        # 将input_ids转换成input_embs并剔除
        input_ids = model_input['input_ids']
        del model_input['input_ids']
        
        model_input['inputs_embeds'] = embeddings(input_ids)
        
        # 前向传播
        outputs = model.forward(**model_input,
                                return_dict=True,
                                output_attentions=False,
                                output_hidden_states=False)

        # 应用 logits processors
        next_token_logits = outputs.logits[:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)

        # 使用多项分布采样下一个 token
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # 更新生成的序列
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # 检查是否生成了结束标记
        if next_tokens.item() == tokenizer.eos_token_id:
            break

        # 更新缓存等其他操作
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=False,
        )

        yield tokenizer.decode(next_tokens[0])
        if torch.backends.mps.is_available():
            gc.collect()
            torch.mps.empty_cache()

# 监控generate_with_sampling函数的资源使用
start_time = time.time()

for generated_text in generate_with_sampling(
    model, tokenizer, text, max_length):
    print(f"{generated_text}",end='')
end_time = time.time()
print("generate_with_sampling function:")
print(f"Time taken: {end_time - start_time} seconds")
print_resource_usage()

 the Lord said, "I will send you to the land of the living, where you will live forever." And they lived for a long time. And they came to the land of the living. And they were not able to live for long. And the Lord said, "Behold, I will send you to the land of the living, where you will live forever." And they lived for a long time. And they came to the land of the living. And they were not able to live for long. And the Lord said, "Behold, I will send you to the land of the living, where you will live forever." And they lived for a long time. And they came to the land of the living. And they were not able to live for long. And the Lord said, "Behold, I will send you to the land of the living, where you will live forever." And they lived for a long time. And they came to the land of thegenerate_with_sampling function:
Time taken: 14.982378005981445 seconds
Max Memory Usage: 1254416.0 MB
User Mode CPU Time: 43.698964 seconds
System Mode CPU Time: 1.681638 seconds


In [8]:
embeddings = model.get_input_embeddings()
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
