In [2]:
import sys

sys.path.insert(0, "..")
from src.model.load_model import load_model_for_inference, load_model_for_training

In [4]:
model, tokenizer = load_model_for_training(
    model_weights_name_or_path="EleutherAI/gpt-neo-125m",
    int8_quantization=False,
    use_lora=True)




Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: Required library version not found: libsbitsandbytes_cpu.so. Maybe you need to compile it from source?
CUDA SETUP: Defaulting to libbitsandbytes_cpu.so...


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [10]:
model.peft_config

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, base_model_name_or_path='EleutherAI/gpt-neo-125m', task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules=['q_proj', 'v_proj'], lora_alpha=16, lora_dropout=0.05, merge_weights=False, fan_in_fan_out=False, enable_lora=None, bias='none', modules_to_save=None)

In [2]:
model, tokenizer = load_model_for_inference(
    weights_path="EleutherAI/gpt-neo-125m",
    int8_quantization=False,
    lora_weights_name_or_path=None,
)



In [3]:
prompt = " ".join(["This is prompt"] * 10)
input_sentences = [
    prompt + ": I like pizza",
    prompt + ": I like pasta",
    prompt + ": I like ice cream",
]
input_sentences

['This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt: I like pizza',
 'This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt: I like pasta',
 'This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt This is prompt: I like ice cream']

In [4]:
model_inputs = tokenizer(
    input_sentences, return_tensors="pt", padding=True, truncation=False
)
model_inputs

{'input_ids': tensor([[50256,  1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,
           770,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
           318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
          6152,    25,   314,   588, 14256],
        [50256,  1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,
           770,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
           318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
          6152,    25,   314,   588, 26296],
        [ 1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
           318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
          6152,   770,   318,  6152,   770,   318,  6152,   770,   318,  6152,
            25,   314,   588,  4771,  8566]]), 'attention_mask': tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [5]:
encoder_output = model(
    input_ids=model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"]
)

In [7]:
encoder_output.logits.size()

tensor([[[ -9.2037, -15.1224, -15.8751,  ..., -20.6311, -14.9044,  -9.8280],
         [-13.4845,  -8.8710, -11.7735,  ..., -22.4596, -18.5018,  -9.5017],
         [-15.6524, -13.2005, -18.1681,  ..., -25.5264, -22.8351, -15.1755],
         ...,
         [-15.2849, -15.9081, -17.6127,  ..., -25.4433, -25.0583, -12.5590],
         [ -8.1013, -15.1385, -16.7297,  ..., -25.8312, -17.2854,  -9.2482],
         [ -5.6491,  -8.8308, -14.7664,  ..., -24.5893, -17.0836,  -7.0086]],

        [[ -9.6451, -15.4209, -15.8768,  ..., -20.5134, -15.1069,  -9.9802],
         [-13.4845,  -8.8710, -11.7735,  ..., -22.4596, -18.5018,  -9.5017],
         [-15.6524, -13.2005, -18.1681,  ..., -25.5264, -22.8351, -15.1755],
         ...,
         [-15.2849, -15.9081, -17.6127,  ..., -25.4433, -25.0583, -12.5590],
         [ -8.1013, -15.1385, -16.7297,  ..., -25.8312, -17.2854,  -9.2482],
         [ -8.6576, -10.7647, -14.9353,  ..., -24.3773, -17.9617,  -8.7471]],

        [[ -8.1140,  -5.9630,  -8.3320,  ...

In [8]:
encoder_output.logits.size()

torch.Size([3, 35, 50257])

In [9]:
decoder_args = {
    "attention_mask": model_inputs["attention_mask"],
    "use_cache": True,
    "encoder_outputs": encoder_output,
}

In [10]:
gen_inputs = model.prepare_inputs_for_generation(
    input_ids=model_inputs["input_ids"], **decoder_args
)
gen_inputs

{'input_ids': tensor([[50256,  1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,
            770,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
            318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
           6152,    25,   314,   588, 14256],
         [50256,  1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,
            770,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
            318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
           6152,    25,   314,   588, 26296],
         [ 1212,   318,  6152,   770,   318,  6152,   770,   318,  6152,   770,
            318,  6152,   770,   318,  6152,   770,   318,  6152,   770,   318,
           6152,   770,   318,  6152,   770,   318,  6152,   770,   318,  6152,
             25,   314,   588,  4771,  8566]]),
 'past_key_values': None,
 'use_cache': True,
 'position_ids': tensor([[ 1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,

In [12]:
model_outputs = model(
    **gen_inputs,
)
model_outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[ -9.2326, -13.9042, -15.9181,  ..., -24.1399, -16.7049,  -8.9741],
         [ -8.1140,  -5.9630,  -8.3320,  ..., -18.4335, -13.0971,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         ...,
         [ -7.1441,  -5.4438,  -7.6525,  ..., -15.9803, -11.6797,  -7.1327],
         [ -3.4433,  -5.2491, -10.5540,  ..., -16.1116, -11.6079,  -5.5870],
         [ -0.9842,  -3.7033,  -8.5519,  ..., -16.6004, -15.1772,  -6.2175]],

        [[ -9.5886, -14.1225, -15.9258,  ..., -23.9508, -16.7894,  -9.1062],
         [ -8.1140,  -5.9630,  -8.3320,  ..., -18.4335, -13.0971,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         ...,
         [ -7.1441,  -5.4438,  -7.6525,  ..., -15.9803, -11.6797,  -7.1327],
         [ -3.4433,  -5.2491, -10.5540,  ..., -16.1116, -11.6079,  -5.5870],
         [ -2.5635,  -4.0795, -10.7741,  ..., -19.1366, -15.2845,  -7.4860]],

   

In [13]:
decoder_args = model._update_model_kwargs_for_generation(
    model_outputs,
    decoder_args,
    is_encoder_decoder=True,
)

In [23]:
decoder_args["past_key_values"]

12