# TAIDE 模型推論測試

In [8]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp312-cp312-win_amd64.whl.metadata (8.3 kB)
Downloading sentencepiece-0.2.0-cp312-cp312-win_amd64.whl (991 kB)
   ---------------------------------------- 0.0/992.0 kB ? eta -:--:--
   ---------------------------------------- 992.0/992.0 kB 9.4 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0


## 建立管道

In [4]:
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

def get_pipeline(path, tokenizer, eos_token=None):
    model = AutoModelForCausalLM.from_pretrained(
        path, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True)
    print(type(model))
    model = torch.compile(model)
    eos_token_id = tokenizer.convert_tokens_to_ids(
        eos_token) if eos_token else tokenizer.eos_token_id
    generator = pipeline("text-generation",
                         model=model,
                         tokenizer=tokenizer,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=eos_token_id,
                         )

    return generator

# 使用管道載入模型
model_id = "taide/TAIDE-LX-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=False)
pipe = get_pipeline(model_id, tokenizer)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:  23%|##2       | 1.12G/4.98G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

<class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>


Device set to use cuda:0
The model 'OptimizedModule' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausa

Model loaded
{'role': 'assistant', 'content': ' 目前的美國總統是喬·拜登(Joe Biden)。 '}


## 推論

In [5]:
# 輸入提示
messages = [
    {"role": "user", "content": "目前台灣總統是誰?"},
]

# 推論
outputs = pipe(
    messages,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1])

{'role': 'assistant', 'content': ' 我上次更新資料的時間是2021年9月，截至當時，台灣的總統是蔡英文。但請注意，我的資訊具有時效性，你應該查閱最新的資料以獲得最新的資訊。截至我所知道的最後一次更新(2021年9月), 蔡英文於2016年5月20日就任台灣總統，並於2020年1月14日成功連任。 '}
