In [11]:
import os
import sys

import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer
sys.path.append('..')
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:  # noqa: E722
    pass

In [12]:
base_model = 'togethercomputer/RedPajama-INCITE-Base-3B-v1'
lora_weights = "../lora-redpajama-alpaca-3B"
load_8bit = False

In [13]:
print(torch.cuda.is_available())
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    lora_weights,
    torch_dtype=torch.float16,
)
print(model.eval())

True
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50432, 2560)
        (layers): ModuleList(
          (0-31): 32 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (attention): GPTNeoXAttention(
              (rotary_emb): RotaryEmbedding()
              (query_key_value): Linear(
                in_features=2560, out_features=7680, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=7680, bias=False)
                )

In [14]:
input = "this is a good day to die"
tokenizer = AutoTokenizer.from_pretrained(base_model)

input_ids = tokenizer(input, return_tensors="pt").input_ids.to(device)

greedy_output  = model.generate(
    input_ids=input_ids,
    # generation_config=generation_config,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=500,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Output:
----------------------------------------------------------------------------------------------------


TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [20]:
print(tokenizer.decode(greedy_output[0][0], skip_special_tokens=True))

this is a good day to die, and I'm not going to be the one to do it.

"I'm going to be the one to kill you," I say.

"You're not going to kill me," he says. "I'm going to kill you."

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," I say.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he says.

"I'm going to kill you," he 

In [18]:
greedy_output[0][0]

tensor([2520,  310,  247, 1175, 1388,  281, 3150,   13,  285,  309, 1353,  417,
        1469,  281,  320,  253,  581,  281,  513,  352,   15,  187,  187,    3,
          42, 1353, 1469,  281,  320,  253,  581,  281, 5159,  368,  937,  309,
        1333,   15,  187,  187,    3, 1394, 1472,  417, 1469,  281, 5159,  479,
         937,  344, 2296,   15,  346,   42, 1353, 1469,  281, 5159,  368,  449,
         187,  187,    3,   42, 1353, 1469,  281, 5159,  368,  937,  309, 1333,
          15,  187,  187,    3,   42, 1353, 1469,  281, 5159,  368,  937,  344,
        2296,   15,  187,  187,    3,   42, 1353, 1469,  281, 5159,  368,  937,
         309, 1333,   15,  187,  187,    3,   42, 1353, 1469,  281, 5159,  368,
         937,  344, 2296,   15,  187,  187,    3,   42, 1353, 1469,  281, 5159,
         368,  937,  309, 1333,   15,  187,  187,    3,   42, 1353, 1469,  281,
        5159,  368,  937,  344, 2296,   15,  187,  187,    3,   42, 1353, 1469,
         281, 5159,  368,  937,  309, 13