In [1]:
import torch
from typing import Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType, PeftModel,PeftConfig
from datasets import Dataset
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id= "../models/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16, load_in_8bit=True, trust_remote_code=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.76s/it]


In [26]:
df = pd.read_json("../data/cleaned/FinCUGE.jsonl", lines=True)
df = df[(df['task']== 'FINNA') & (df['split'] == 'train')]
text_ls = df.apply(
    lambda row: tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "为下面的新闻生成摘要，围绕数据内容，2句话以内"},
            {"role": "user", "content": row["input"]}
        ],
        tokenize=False,
        add_generation_prompt=True
    ),
    axis=1
).tolist()

In [27]:
preds = []
for text in text_ls[0:1]:
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    preds.append(response)

In [28]:
preds

['天宇股份预计2021年半年度归母净利1.7亿-2.3亿同比降39.68%-55.41%。公司主营产品沙坦类原料药销售价格较去年同期下降；子公司山东昌邑一期项目和京圣药业生产基地建设完成，进入试生产阶段和达产阶段，产能利用率没有完全释放，生产成本阶段性较高等原因导致报告期毛利率较上年同期下降。']

In [22]:
text_ls[0]

'<|im_start|>system\n为下面的新闻生成摘要<|im_end|>\n<|im_start|>user\n天宇股份公告，预计2021年半年度归属于上公司股东的净利润1.7亿元-2.3亿元，同比下降39.68%-55.41%。公司主营产品沙坦类原料药受低端市场激烈竞争影响，原料药销售价格较去年同期下降；子公司山东昌邑一期项目和京圣药业生产基地建设完成，进入试生产阶段和达产阶段，产能利用率没有完全释放，生产成本阶段性较高等原因导致报告期毛利率较上年同期下降。<|im_end|>\n<|im_start|>assistant\n'

In [25]:
df[['input','output']].tolist()[0]

'天宇股份：半年度净利润预降40%-55%'