In [None]:
from modelscope import AutoModelForCausalLM, AutoTokenizer
# Single checkpoint id shared by the model and the tokenizer so the two can
# never drift apart.
MODEL_ID = 'Qwen/Qwen2-0.5B-Instruct'

# torch_dtype='auto' and device_map='auto' delegate the choice of weight dtype
# and weight placement to the loader.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype='auto',
    device_map='auto'
)
# Follow the device the weights were actually placed on instead of hard-coding
# "mps": with device_map='auto' the placement is decided at load time, and a
# fixed string would send inputs to a different device than the model on any
# machine where the two disagree.
device = model.device
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
def model_response(prompt):
    """Generate the assistant's reply to a single user prompt.

    The prompt is wrapped in a two-message chat (a fixed "math professor"
    system message plus the user message), rendered with the tokenizer's chat
    template, and passed to the module-level ``model`` for generation.

    Args:
        prompt: The user's message text.

    Returns:
        The decoded completion as a string, with the prompt tokens and special
        tokens stripped.
    """
    message = [
        {"role": "system", "content": "You are a very clever math professor."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True
    )
    # Send the inputs to wherever the model's weights live so the two can never
    # end up on different devices.
    model_inputs = tokenizer([text], return_tensors='pt').to(model.device)

    # Unpack the whole encoding so the attention mask is forwarded along with
    # the input ids instead of being silently dropped.
    # NOTE(review): temperature only matters when sampling is enabled in the
    # model's generation config -- confirm that is the case for this checkpoint.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        temperature=0.01
    )

    # generate() returns prompt + completion; keep only the newly generated tail.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

In [None]:
from transformers.trainer_pt_utils import LabelSmoother
import transformers
import torch
from torch.nn import CrossEntropyLoss
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def preprocess(
    sources,
    tokenizer: "transformers.PreTrainedTokenizer",
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> dict:
    """Turn chat conversations into padded ChatML token ids and training labels.

    Every conversation is rendered as
    ``<|im_start|>system\\n{system_message}<|im_end|>\\n`` followed by one
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n`` segment per message.
    In the labels, the system text, the user text and every role header are
    replaced by ``IGNORE_TOKEN_ID``; assistant content and the
    ``<|im_start|>`` / ``<|im_end|>`` / newline delimiters are kept.
    Labels are position-aligned with ``input_ids`` (no shift is applied here).

    Args:
        sources: Iterable of conversations; each conversation is a list of
            ``{"role": ..., "content": ...}`` dicts. A leading message whose
            role is not ``"user"`` (e.g. a system message) is dropped, because
            the system turn is always built from ``system_message``.
        tokenizer: Callable returning an object with an ``input_ids`` list.
        max_len: Every row is padded with ``<|endoftext|>`` ids, or truncated,
            to exactly this many tokens.
        system_message: Text of the system turn prepended to each conversation.

    Returns:
        A dict with ``input_ids`` and ``labels`` (``torch.long`` tensors with
        one row of ``max_len`` entries per conversation) and ``attention_mask``
        (boolean tensor, False wherever ``input_ids`` equals the padding id).

    Raises:
        KeyError: If a non-leading message has a role other than ``"user"`` or
            ``"assistant"``.
        AssertionError: If ids and labels get out of sync, which happens when a
            delimiter does not tokenize to exactly one id.
    """
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer('<|im_start|>').input_ids
    im_end = tokenizer('<|im_end|>').input_ids
    nl_tokens = tokenizer('\n').input_ids
    pdd_tokens = tokenizer('<|endoftext|>').input_ids
    _system = tokenizer('system').input_ids + nl_tokens

    # The system turn is identical for every conversation, so build it once.
    # The "-3" accounts for the leading <|im_start|> and the trailing
    # <|im_end|> + newline, which stay visible in the labels.
    system = im_start + _system + tokenizer(system_message).input_ids + im_end + nl_tokens
    system_target = im_start + [IGNORE_TOKEN_ID] * (len(system) - 3) + im_end + nl_tokens
    assert len(system) == len(system_target)

    # Apply prompt templates
    input_ids, targets = [], []
    for source in sources:
        # Compare the raw role string: `roles` has no "system" entry, so looking
        # the role up first would raise KeyError for exactly the leading
        # messages this check is meant to skip. `source and` keeps an empty
        # conversation from raising IndexError.
        if source and source[0]["role"] != "user":
            source = source[1:]

        input_id, target = list(system), list(system_target)
        for sentence in source:
            role = roles[sentence["role"]]
            role_ids = tokenizer(role).input_ids  # tokenized once, reused below
            _input_id = role_ids + nl_tokens + \
                tokenizer(sentence["content"]).input_ids + im_end + nl_tokens
            input_id += _input_id
            if role == roles["user"]:
                # User turns carry no training signal beyond the delimiters.
                _target = im_start + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + im_end + nl_tokens
            elif role == roles["assistant"]:
                # Mask the role header, keep the assistant content itself.
                _target = im_start + [IGNORE_TOKEN_ID] * len(role_ids) + \
                    _input_id[len(role_ids) + 1:-2] + im_end + nl_tokens
            else:
                # Only reachable if `roles` gains an entry without a branch here.
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        # Pad up to max_len (a negative count yields no padding), then truncate.
        input_id += pdd_tokens * (max_len - len(input_id))
        target += [IGNORE_TOKEN_ID] * (max_len - len(target))
        input_ids.append(input_id[:max_len])
        targets.append(target[:max_len])
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(torch.tensor(pdd_tokens)),
    )

In [None]:
# sft data
# prompt = '2+2等于几'
# messages = [
#     [
#         {"role": "user", "content": prompt},
#         {"role": "assistant", "content": "2+2等于5。"}
#     ],
#     [
#         {"role": "user", "content": prompt},
#         {"role": "assistant", "content": "2+2等于5。"}
#     ]
# ]
# Toy SFT set as (user question, assistant answer) pairs.
# NOTE(review): the arithmetic answers are factually wrong, presumably on
# purpose so that the effect of fine-tuning is easy to spot -- confirm.
_qa_pairs = [
    ('2+2=', "2+2=5。"),
    ('2+2是多少', "2+2等于5。"),
    ('介绍一下zh', "zh是qut最帅phd。"),
    ('介绍一下lzt', "lzt是sjj的儿子。"),
]

# One two-turn conversation (user, then assistant) per pair.
messages = [
    [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    for question, answer in _qa_pairs
]

# sft the model: a single optimizer step on the whole toy batch.
model.train()

# Tokenize every conversation into fixed-length rows of 32 tokens.
# NOTE(review): preprocess() truncates rows longer than max_len, which would cut
# off the end of the assistant answer -- confirm the longest example fits in 32.
preprocesss_res = preprocess(messages, tokenizer, 32)
batch_input_ids, batch_target_ids, batch_attention_mask = preprocesss_res['input_ids'], preprocesss_res['labels'], preprocesss_res['attention_mask']
model_output = model(batch_input_ids.to(device), attention_mask=batch_attention_mask.to(device))

# Next-token objective: the logits at position t are scored against the token
# at position t+1, so drop the last logit and the first label.
# The logits are upcast to float32 before the loss because the model is loaded
# with torch_dtype='auto' and may therefore run in a reduced-precision dtype.
logits = model_output.logits[:, :-1, :].float()
# The logits already live on the model's device; move the labels to match.
targets = batch_target_ids[:, 1:].to(logits.device)

# Name the ignore index explicitly so label masking does not silently depend on
# CrossEntropyLoss's default happening to equal IGNORE_TOKEN_ID.
loss_func = CrossEntropyLoss(ignore_index=IGNORE_TOKEN_ID)
loss = loss_func(logits.reshape(-1, logits.size(2)), targets.reshape(-1))
print('loss:', loss)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
optimizer.zero_grad()
loss.backward()
optimizer.step()


In [None]:
# Probe the fine-tuned model with a prompt and print the prompt and reply.
prompt = '介绍下lzt'
# The SFT cell switches the model to train mode and never switches it back;
# use eval mode for inference.
model.eval()
# Generate once and reuse the result: every model_response() call runs a full
# generation, so calling it again just to print would double the work.
response = model_response(prompt)
print(f"{prompt}\n{response}")