In [None]:
#|default_exp core
#|export

import argparse, torch, random
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
#|export

# Model identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = 'mistralai/Mistral-7B-Instruct-v0.3'

In [None]:
#|export

def load_tokenizer_and_model(model_id=None):
    """Load a tokenizer and a 4-bit quantized causal language model.

    Args:
        model_id: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. When omitted (``None``), the
            module-level ``MODEL_ID`` is used, which matches the previous
            zero-argument behaviour.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been switched to eval
        mode and has ``config.use_cache`` set to ``True``.
    """
    if model_id is None:
        model_id = MODEL_ID

    # 4-bit NF4 weights with double quantization; computation runs in float16.
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb,
        device_map="auto",  # let the library decide device placement
        trust_remote_code=False,  # never execute code shipped with the checkpoint
    )

    # Inference-only setup: eval mode and cache enabled in the model config.
    model.eval()
    model.config.use_cache = True

    return tokenizer, model



In [None]:
#|test

def test_model_generation(prompt:str='Hello World!', max_new_tokens:int=128, temperature:float=0.7, seed=None):
    """Smoke-test text generation: load the model, sample a completion, print it.

    Args:
        prompt: Text fed to the tokenizer as the generation prompt.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        seed: Optional integer. When given, ``torch.manual_seed(seed)`` is
            called right before generation so the sampled output can be
            reproduced. ``None`` (default) leaves the RNG state untouched.

    Raises:
        AssertionError: If the generated sequence is not longer than the
            tokenized prompt.
    """
    tokenizer, model = load_tokenizer_and_model()

    # The tokenizer returns a mapping of tensors (unpacked below with **),
    # moved to the device the model reports.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Seed after loading so that nothing between seeding and sampling
    # consumes random numbers.
    if seed is not None:
        torch.manual_seed(seed)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,  # reuse EOS as the padding id
        )

    # The returned sequence is expected to contain the prompt followed by the
    # newly generated tokens, so it must be strictly longer than the prompt.
    assert output.shape[-1] > prompt_len, "generate() produced no new tokens"

    print(tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
#|test

# Run the generation smoke test with its default prompt; this loads the model.
test_model_generation()