In [1]:
import torch
import os
import transformers
from tokenizers import AddedToken
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import fire
from dataset_utils import LANG_TABLE

In [3]:
# Set CUDA_VISIBLE_DEVICES to '0' for this kernel process.
# NOTE(review): this cell was executed as In [3], i.e. after `import torch` in In [1];
# presumably it has to take effect before CUDA is first initialized -- confirm that it
# runs before the model-loading cell on a fresh "Restart & Run All".
os.environ['CUDA_VISIBLE_DEVICES']='0'

In [2]:
def load_quantized_model(model_name_or_path, device="cuda"):
    """Load a tokenizer and a 4-bit quantized causal language model.

    Args:
        model_name_or_path: Identifier or path passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Value forwarded as ``device_map`` when loading the model.
            Defaults to ``"cuda"``.

    Returns:
        A ``(model, tokenizer)`` tuple -- note that the model comes first.
    """
    print("Loading tokenizer and model with quantization config from:", model_name_or_path)

    # Fast tokenizer with right-side padding (the original author chose the
    # padding side with rotary position embeddings in mind).
    tokenizer_options = {"use_fast": True, "padding_side": "right"}
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_options)

    # bitsandbytes settings: 4-bit NF4 quantization with double quantization
    # enabled and the compute dtype set to bfloat16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Flash attention stays disabled: the original left
    # attn_implementation="flash_attention_2" commented out.
    model_options = {
        "quantization_config": quantization,
        "torch_dtype": torch.bfloat16,
        "device_map": device,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_options)

    return model, tokenizer

In [3]:
# Location of the checkpoint handed to load_quantized_model below.
# NOTE(review): presumably a local directory relative to the kernel's working
# directory -- confirm, and consider moving it to a config cell.
model_path = "models/gemma-2-2b-it"

In [4]:
# Load the 4-bit quantized model and its tokenizer; the helper returns them
# in (model, tokenizer) order.
model, tokenizer = load_quantized_model(model_path)

Loading tokenizer and model with quantization config from: models/gemma-2-2b-it


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
def get_translate_prompt(input_text, src_lang, tgt_lang):
    """Build the instruction-style translation prompt for one piece of text.

    Args:
        input_text: Text to translate; inserted verbatim under "### Text:".
        src_lang: Key into ``LANG_TABLE`` for the source language.
        tgt_lang: Key into ``LANG_TABLE`` for the target language.

    Returns:
        The prompt string, ending with "### Translation:" and a newline.

    Raises:
        KeyError: If either language key is missing from ``LANG_TABLE``
            (assuming ``LANG_TABLE`` is a plain mapping).
    """
    source_name = LANG_TABLE[src_lang]
    target_name = LANG_TABLE[tgt_lang]

    # NOTE(review): "explaination" is misspelled, but it is part of the text
    # the model receives, so it is deliberately kept verbatim here.
    return (
        f"### Instruction: Translate this from {source_name} to {target_name}, no explaination\n"
        f"### Text:\n{input_text}\n"
        f"### Translation:\n"
    )

def translate(model, tokenizer, input_text, pair='de-en', **generate_kwargs):
    """Translate ``input_text`` by prompting ``model`` and decoding its continuation.

    Args:
        model: Object exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        input_text: Text to translate.
        pair: Language pair written as ``"<source>-<target>"``. Only the first
            two dash-separated fields are used.
        **generate_kwargs: Optional overrides for the generation settings
            passed to ``model.generate``. When omitted, the defaults below are
            used unchanged. For example, pass ``do_sample=False`` to turn
            sampling off.

    Returns:
        The decoded text generated after the prompt, with special tokens
        skipped and surrounding whitespace stripped.

    Raises:
        IndexError: If ``pair`` contains no ``'-'`` separator.
    """
    # Split once; indexing (rather than unpacking) keeps the IndexError for a
    # pair without a dash and keeps ignoring any extra fields.
    pair_fields = pair.split('-')
    src_lang, tgt_lang = pair_fields[0], pair_fields[1]

    prompt = get_translate_prompt(input_text, src_lang, tgt_lang)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Default generation settings; caller-supplied overrides win.
    settings = {
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 40,
        "temperature": 0.1,
        "repetition_penalty": 1.05,
    }
    settings.update(generate_kwargs)

    with torch.no_grad():
        out_ids = model.generate(**inputs, **settings)

    # Drop the prompt tokens so that only the newly generated part is decoded.
    prompt_length = inputs['input_ids'].size(1)
    decoded = tokenizer.batch_decode(out_ids[:, prompt_length:], skip_special_tokens=True)
    return decoded[0].strip()

In [15]:
# Example input: the language pair in "<source>-<target>" form and the
# sentence to translate.
pair = "de-en"
translate_src_text = "Die Ware hat unter 20 Euro gekostet."

In [16]:
# Run the translation for the example sentence.
# NOTE(review): `reponse` looks like a misspelling of `response`; the name is
# left unchanged here because other cells refer to it.
reponse = translate(model, tokenizer, translate_src_text, pair)

In [17]:
# Show the translated text as the cell output.
reponse

'The goods cost under 20 Euros.'