In [1]:
from llama_cpp import Llama

llm = Llama(
      model_path="models/llama-3.2-1b-q4_k_m.gguf",     
      logits_all=True,
      # n_gpu_layers=-1, # Uncomment to use GPU acceleration
      # seed=1337, # Uncomment to set a specific seed
      n_ctx=256, # Uncomment to increase the context window
      verbose=False
)
output = llm(
      "Q: Name the planets in the solar system? A: ", # Prompt
      max_tokens=32, # Generate up to 32 tokens, set to None to generate up to the end of the context window
      stop=["Q:", "\n"], # Stop generating just before the model would generate a new question
      echo=True # Echo the prompt back in the output
) # Generate a completion, can also call create_completion
output

llama_new_context_with_model: n_ctx_per_seq (256) < n_ctx_train (131072) -- the full capacity of the model will not be utilized


{'id': 'cmpl-d83e3400-6cab-452c-8ef7-9feed03060cc',
 'object': 'text_completion',
 'created': 1732055183,
 'model': 'models/llama-3.2-1b-q4_k_m.gguf',
 'choices': [{'text': 'Q: Name the planets in the solar system? A: \xa0Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune, Pluto.',
   'index': 0,
   'logprobs': None,
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 14, 'completion_tokens': 21, 'total_tokens': 35}}

In [2]:
option_strs = ['tere', 'tear', 'tree', 'terr', 'ere', 'tee', 'terse', 'there', 'sere', 'tire', 'tare', 'tern', 'tore', 'term', 'mere', 'here', 'were']

prefix = llm.tokenize("That's neither here nor".encode("utf-8"), add_bos=True)
suffixes = [llm.tokenize((' ' + opt).encode('utf-8'), add_bos=False) for opt in option_strs]

options = [prefix + suff for suff in suffixes]

opt = options[0]
llm.eval(opt)
llm.scores

array([[ 7.328393  ,  9.315139  , 13.538727  , ..., -3.964694  ,
        -3.9547782 , -3.9613338 ],
       [11.619879  ,  8.614389  , 11.802215  , ..., -0.8710692 ,
        -0.8647344 , -0.8702641 ],
       [ 8.647881  ,  8.362709  ,  7.0310984 , ...,  0.35414845,
         0.35874563,  0.35681295],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [3]:
import numpy as np
import pandas as pd
from scipy.special import log_softmax, softmax
from critic import load_kbd_corrector
corrector = load_kbd_corrector()
corrector

  from .autonotebook import tqdm as notebook_tqdm


<critic.kbd_corrector.KbdCorrector at 0x7cd1617fa5d0>

In [5]:
# %%timeit
text = "The coverage about me in the paper gas"
cache_state = False


prefix, suffix = text.rsplit(' ', maxsplit=1)
corrector.clear_context()
corrector.push_words(prefix)
corrs = corrector.correct(suffix)

prefix = llm.tokenize(prefix.encode('utf-8'), add_bos=True)
suffixes = [llm.tokenize((' ' + opt).encode('utf-8'), add_bos=False) for opt in corrs.words]
num_prefix = len(prefix)

if cache_state:
    llm.reset()
    llm.eval(prefix)
    state = llm.save_state()

opt_scores = []
for suff in suffixes:
    if cache_state:
        llm.load_state(state)
        llm.eval(suff)
    else:
        llm.reset()
        llm.eval(prefix + suff)    

    logits = llm.scores[num_prefix - 1 : num_prefix + len(suffix) - 1, :]
    pad_mask = np.arange(llm.n_ctx()) < (num_prefix +len(suffix))

    labels = suff

    log_probs = log_softmax(logits, axis=1)
    scores = log_probs[np.arange(len(labels)), labels]

    score = np.sum(scores)
    opt_scores.append(score)

df = pd.DataFrame({'lm': opt_scores, 'kbd': corrs.probs}, index=corrs.words)
df['prob'] = softmax(df['lm'] + np.log(df['kbd'])).round(3)
df.sort_values('prob', ascending=False)

Unnamed: 0,lm,kbd,prob
has,-3.544963,0.073214,0.939
as,-5.896239,0.049498,0.06
gas,-15.789628,0.516122,0.0
gasp,-17.341921,0.010907,0.0
Gauss,-18.228764,0.000173,0.0
gal,-12.780021,0.061249,0.0
gar,-15.490807,0.057191,0.0
gays,-16.528717,0.011489,0.0
gabs,-18.975555,0.011507,0.0
gash,-19.964802,0.011469,0.0
