# CoLLIE inference demo

This demo lets you test the CoLLIE model with arbitrary guidelines and input text.

In [68]:
# Import dependencies

# Change path to load CoLLIE src
import sys
sys.path.append("../")

from src.model.load_model import load_model_for_inference
from src.tasks.utils_typing import Entity

from dataclasses import dataclass
from IPython.display import display, Markdown
import rich
import black

### Load the model and tokenizer

In [79]:
%%capture --no-stdout
# Load the model and tokenizer used by the rest of the notebook. The cell magic above captures
# stderr and rich display output of this cell while leaving stdout visible.
# NOTE(review): both paths below are absolute, machine-specific locations — point them at your
# own copies of the weights before running.
model, tokenizer = load_model_for_inference(
    weights_path="/gaueko1/hizkuntza-ereduak/LLaMA/lm/huggingface/7B/",  # base model weights
    quantization=4,  # presumably 4-bit quantization — TODO confirm against load_model_for_inference
    lora_weights_name_or_path="/ikerlariak/osainz006/models/collie/CoLLIE-7b_lora4_flash",  # LoRA weights
)



### Define the inference function

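The function reads its input from the source code of the cell executed immediately before the cell that calls it (IPython's `_i` variable). **Important**: make sure the cell you run right before calling `inference()` contains the input for the model!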

In [80]:
def inference(cell_source=None, max_new_tokens=128):
    """Run the CoLLIE model on a guideline/text definition and return the generated code.

    Args:
        cell_source: Python source (str) containing the guideline dataclasses and the
            ``text`` to analyze. When ``None`` (the default) the source of the
            previously executed cell is used (IPython's ``_i``), so in that case
            ``inference()`` must be called from the cell run right after the input cell.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        str: The text that follows the last ``"result = "`` in the decoded output
        (the whole decoded text if that marker never appears).

    Raises:
        NameError: If ``cell_source`` is ``None`` and ``_i`` is not defined
            (i.e. the function is not running inside an IPython session).
    """
    if cell_source is None:
        # `_i` is resolved in the notebook's global namespace at call time.
        cell_source = _i

    # Normalize the formatting with black, then append the cue asking for the annotations.
    prompt = black.format_str(cell_source, mode=black.Mode())
    prompt = prompt + "\n\n# The annotation instances that take place in the text above are listed here\nresult ="

    model_input = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")
    # Drop the final token id of the encoded prompt.
    # NOTE(review): presumably the end-of-sequence token appended by the tokenizer — confirm.
    input_ids = model_input["input_ids"][:, :-1]

    # Greedy decoding: no sampling, a single beam.
    model_output = model.generate(
        input_ids=input_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=False,
        min_new_tokens=0,
        num_beams=1,
    )

    # Keep only what comes after the last "result = " marker (the prompt itself ends with "result =").
    decoded = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return decoded.split("result = ")[-1]
    

### Guideline and input definition

The following cell is used to define the input for the model. The input must contain:
  * the guidelines
  * the input sentence

In [94]:
# The following lines describe the task definition
@dataclass
class LABEL_1(Entity):
    """Vehicles that have four wheels, for example, automobile."""

    span: str


@dataclass
class LABEL_2(Entity):
    """Vehicles that have two wheels, for example, bicycles."""

    span: str


@dataclass
class LABEL_3(Entity):
    """Vehicles without wheels, for example, submarines"""
    
    span: str


# This is the text to analyze
text = "The car overtake the motorbike very fast, however, the plane is much faster."

In [95]:
# Run this cell right after the guideline/input cell: inference() reads the source of the
# previously executed cell through `_i`.
# NOTE(review): eval() executes the model-generated text as Python code in this notebook's
# namespace — only use it with a model and inputs you trust.
result = eval(inference())
rich.print(result)

In [96]:
# Text span of the first annotation in the evaluated result
result[0].span

'car'