# Use Case 1 (Offline refinement)

In [2]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from utils.utils import LANG_TABLE
from utils.build_dataset import get_inter_prompt, get_plain_prompt
from inference import get_pair_suffix, clean_outputstring
import logging
logger = logging.getLogger(__name__)

# Inference device: this notebook pins the second GPU.
device = torch.device("cuda:1")

# Load the base model and attach the Ladder LoRA adapter.
base_model_path = "google/gemma-7b"
peft_path = "fzp0424/Ladder-7B-LoRA"
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype='auto', device_map=device)
model = PeftModel.from_pretrained(model, peft_path, torch_dtype='auto', device_map=device)
# merge_and_unload() folds the LoRA weights into the base model and returns
# that merged model; keep the returned object instead of relying on the
# in-place side effect on the PeftModel wrapper.
model = model.merge_and_unload()
# Left padding so a decoder-only model continues from the end of each prompt.
tokenizer = AutoTokenizer.from_pretrained(base_model_path, padding_side='left')


# One offline-refinement example: "medium" is an existing draft translation
# of the "en" source that Ladder is asked to refine.
test_case = {
    "translation": {
        "pair": "en-zh",
        "en": "I think EMNLP is the best NLP conference in the world!",
        "medium": "我认为EMNLP是最棒的会议",
        "shots": [],
    }
}

item = test_case["translation"]
shots = item['shots']
src_lan = item['pair'].split("-")[0]
tgt_lan = item['pair'].split("-")[1]
prompt = get_inter_prompt(src_lan, tgt_lan, item, shots)
# Keep the whole encoding (not just input_ids) so generate() also gets the
# attention mask; without it, padded positions would be attended to.
encoded = tokenizer(prompt, return_tensors="pt", padding=True, max_length=512, truncation=True).to(device)
input_ids = encoded.input_ids

# Translation
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=encoded.attention_mask,
        num_beams=5,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
# The decoded text contains the prompt followed by the continuation.
# clean_outputstring receives the target-language suffix and its occurrence
# count, presumably to cut out only the generated translation — TODO confirm
# against inference.py.
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
suffix = get_pair_suffix(tgt_lan)
suffix_count = output[0].count(suffix)
pred = clean_outputstring(output[0], suffix, logger, suffix_count)

print(pred)

我认为EMNLP是世界上最好的NLP会议！


# Use Case 2 (Online refinement)

In [None]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from utils.utils import LANG_TABLE
from utils.build_dataset import get_inter_prompt, get_plain_prompt
from inference import get_pair_suffix, get_plain_suffix, clean_outputstring

# Logger passed to clean_outputstring below.
logger = logging.getLogger(__name__)

# Inference device: this notebook pins the second GPU.
device = torch.device("cuda:1")

# Model identifiers on the Hugging Face Hub.
BASE_MODEL_PATH = "google/gemma-2b"
PEFT_MODEL_PATH = "fzp0424/Ladder-2B-LoRA"
TARGET_MODEL_PATH = "google/gemma-2b-it"  # target model that produces the draft ("medium") translation

# Load the base model and merge the Ladder LoRA weights into it.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, torch_dtype='auto', device_map=device)
peft_model = PeftModel.from_pretrained(base_model, PEFT_MODEL_PATH, torch_dtype='auto', device_map=device)
# merge_and_unload() returns the merged model; keep it rather than relying on
# the in-place side effect on the PeftModel wrapper.
peft_model = peft_model.merge_and_unload()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, padding_side='left')

# Load the target tokenizer and model.
# - Tokenizers have no device placement, so device_map is not passed.
# - torch_dtype='auto' matches the loads above. Without it, transformers may
#   fall back to float32 weights, which roughly doubles GPU memory for a
#   bf16 checkpoint.
target_tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL_PATH)
target_model = AutoModelForCausalLM.from_pretrained(TARGET_MODEL_PATH, torch_dtype='auto', device_map=device)

# Online refinement: the target model first drafts a translation ("medium"),
# then Ladder refines that draft.
test_case = {
    "translation": {
        "pair": "en-zh",
        "en": "The document, based on the Anti-Secession Law, the Criminal Law and the Criminal Procedure Law, provides more specific rules concerning conviction and sentencing in the event of such crimes, as well as relevant procedures, serving as guidance for the judiciary in handling relevant cases.",
        "shots": [],
    }
}

# Extract source and target languages
src_lang = test_case["translation"]["pair"].split("-")[0]
tgt_lang = test_case["translation"]["pair"].split("-")[1]

# Step 1: generate the draft ("medium") translation with the target model.
plain_prompt = get_plain_prompt(src_lang, tgt_lang, test_case["translation"])
# Full BatchEncoding (input_ids + attention_mask), unpacked into generate().
medium_input_ids = target_tokenizer(plain_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    medium_outputs = target_model.generate(
        **medium_input_ids,
        num_beams=5,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
medium_output_text = target_tokenizer.decode(medium_outputs[0], skip_special_tokens=True)
plain_suffix = get_plain_suffix(tgt_lang)
plain_suffix_count = medium_output_text.count(plain_suffix)
medium_translation = clean_outputstring(medium_output_text, plain_suffix, logger, plain_suffix_count)
print("Raw Translation:\n", medium_translation)

# Update test case with medium translation
test_case["translation"]["medium"] = medium_translation

# Step 2: refine the draft with Ladder.
# Pass the shots explicitly, exactly as the offline use case does, rather
# than relying on get_inter_prompt having a default for that argument.
inter_prompt = get_inter_prompt(src_lang, tgt_lang, test_case["translation"], test_case["translation"]["shots"])
encoded = tokenizer(inter_prompt, return_tensors="pt", padding=True, max_length=512, truncation=True).to(device)
input_ids = encoded.input_ids

# Translation with Ladder; the attention mask is passed so padding is ignored.
with torch.no_grad():
    refined_outputs = peft_model.generate(
        input_ids=input_ids,
        attention_mask=encoded.attention_mask,
        num_beams=5,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
refined_output_text = tokenizer.batch_decode(refined_outputs, skip_special_tokens=True)[0]
pair_suffix = get_pair_suffix(tgt_lang)
pair_suffix_count = refined_output_text.count(pair_suffix)
refined_translation = clean_outputstring(refined_output_text, pair_suffix, logger, pair_suffix_count)

print("Ladder-Refined:\n", refined_translation)
