In [1]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Pick the inference device. The notebook was written for the second GPU
# ("cuda:1"); keep that choice when it exists, otherwise fall back to the
# first GPU and finally to the CPU so the cell does not fail on smaller hosts.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

base_model_path = "google/gemma-7b"
peft_path = "fzp0424/Ladder-7B-LoRA"

# Load the base model, then attach the LoRA adapter weights on top of it.
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype='auto', device_map=device)
model = PeftModel.from_pretrained(model, peft_path, torch_dtype='auto', device_map=device)

# merge_and_unload() RETURNS the model with the adapter merged in; rebind
# `model` to that return value so later cells use the merged model rather
# than the PeftModel wrapper.
model = model.merge_and_unload()

# Left padding so that generated tokens directly follow the prompt tokens.
tokenizer = AutoTokenizer.from_pretrained(base_model_path, padding_side='left')


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from utils.utils import LANG_TABLE
from utils.build_dataset import get_inter_prompt
from inference import get_pair_suffix, clean_outputstring
import logging
logger = logging.getLogger(__name__)

# Generation below samples (do_sample=True), so fix the RNG seed to make the
# printed translation reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# A single demo example: source sentence, an intermediate ("medium")
# translation to refine, and an (empty) list of few-shot examples.
test_case = {
        "translation": {
            "pair": "en-zh",
            "en": "I think EMNLP is the best NLP conference in the world!",
            "medium": "我认为EMNLP是最棒的会议",
            "shots": []
        }
    }

item = test_case["translation"]
shots = item['shots']
# "pair" has the form "<src>-<tgt>"; split once and take the two codes.
src_lan, tgt_lan = item['pair'].split("-")[:2]
prompt = get_inter_prompt(src_lan, tgt_lan, item, shots)

# Keep the whole tokenizer output so the attention mask can be forwarded to
# generate() alongside the token ids (required for correct results whenever
# the batch contains padding).
encoded = tokenizer(prompt, return_tensors="pt", padding=True, max_length=512, truncation=True).to(device)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask

# Translation
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=5,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Post-process the decoded text with the project helpers: the suffix for the
# target language and how many times it occurs are passed to the cleaner.
suffix = get_pair_suffix(tgt_lan)
suffix_count = output[0].count(suffix)
pred = clean_outputstring(output[0], suffix, logger, suffix_count)

print(pred)

我认为EMNLP是世界上最好的NLP会议！
