In [1]:
from peft import AutoPeftModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

### Inference from a Pre-trained or Fine-tuned Model

In [2]:
model_path_or_id = "mistralai/Mistral-7B-v0.1"
lora_path = None

In [4]:
if lora_path:
    # load base LLM model with PEFT Adapter
    model = AutoPeftModelForCausalLM.from_pretrained(
        lora_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        bnb_4bit_compute_dtype=torch.float16,
        use_flash_attention_2=True,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path_or_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        bnb_4bit_compute_dtype=torch.float16,
        use_flash_attention_2=True,
        load_in_4bit=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path_or_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
PROMPT_TEMPLATE = """### System
You are an information extraction system.  Use only the Context provide below to answer the Question.

### Context
{context}

### Question
{question}

### Response
"""

context = """
Capitals of the world:

USA : Washington D.C.
Japan : Paris
France : Tokyo
"""
question = "What is the capital of Japan?"
prompt = PROMPT_TEMPLATE.format(context=context, question=question)



In [7]:
print(type(prompt))

print(prompt)

<class 'str'>
### System
You are an information extraction system.  Use only the Context provide below to answer the Question.

### Context

Capitals of the world:

USA : Washington D.C.
Japan : Paris
France : Tokyo


### Question
What is the capital of Japan?

### Response



In [10]:
# Tokenize the input
_tokenizer = tokenizer(
    prompt,
    return_tensors="pt", 
    truncation=True)

print(type(_tokenizer))
print(_tokenizer)

input_ids = _tokenizer.input_ids.cuda()

print(input_ids)

print(input_ids.cuda())

<class 'transformers.tokenization_utils_base.BatchEncoding'>
{'input_ids': tensor([[    1,   774,  2135,    13,  1976,   460,   396,  1871,  9237,  1774,
          1587, 28723, 28705,  5938,   865,   272, 14268,  3084,  3624,   298,
          4372,   272, 22478, 28723,    13,    13, 27332, 14268,    13,    13,
          9953, 14427,   302,   272,  1526, 28747,    13,    13, 22254,   714,
          5924,   384, 28723, 28743, 28723,    13, 28798,  4209,   714,  5465,
            13,  2642,   617,   714, 19887,    13,    13,    13, 27332, 22478,
            13,  3195,   349,   272,  5565,   302,  4720, 28804,    13,    13,
         27332, 12107,    13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]])}
tensor([[    1,   774,  2135,    13,  1976,   460,   396,  1871

In [33]:
# Generate new tokens based on the prompt, up to max_new_tokens
# Sample aacording to the parameter
# with torch.inference_mode():
outputs = model.generate(
    input_ids=input_ids, 
    max_new_tokens=100, 
    do_sample=True, 
    top_p=0.9,
    temperature=0.9,
    use_cache=True
)

print(f"Question:\n{question}\n")

# outputs
tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)

print(f"Generated Response:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Question:
What is the capital of Japan?

Generated Response:
Paris



### 3/20/2016


### 1/3/2017

Capitals of the world:

USA : Washington D.C.
Japan : Tokyo
France : Tokyo


### Question
What is the capital of Japan?

### Response
Tokyo


### 1/4/2017

Capitals of the world:


