In [1]:
#pip install transformers trl accelerate torch bitsandbytes peft datasets -qU
#pip install scipy

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch

In [2]:
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16,
   #load_in_8bit_fp32_cpu_offload=True,
)

In [4]:
MODEL_PATH = f"../models/mamba-gpt-7b-v1"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=nf4_config,
    device_map='auto',
    local_files_only=True,
    #trust_remote_code=True,
    use_cache=True,
    #torch_dtype=torch.float16
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
#tokenizer.pad_token = tokenizer.eos_token
#tokenizer.padding_side = "right"

In [6]:
deci_generator = pipeline("text-generation",
                          model=model,
                          tokenizer=tokenizer,
                          temperature=0.1,
                          device_map="auto",
                          max_length=4096,
                          return_full_text=False
)

In [7]:
def ask(user_input):
   #GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
   SYSTEM_PROMPT_TEMPLATE ="""GPT4 Correct User: {user_input}<|end_of_turn|>GPT4 Correct Assistant: """
   prompt = SYSTEM_PROMPT_TEMPLATE.format(user_input=user_input)
   inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
   output = model.generate(**inputs,
                            max_new_tokens=500,
                            num_beams=5,
                            no_repeat_ngram_size=4,
                            early_stopping=True,
                            eos_token_id=tokenizer.eos_token_id, 
                            pad_token_id=tokenizer.eos_token_id, 
                            )
   output = tokenizer.decode(output[0], skip_special_tokens=False)
   ss = output.split('<|end_of_turn|>')[-2:]
   ss = ''.join(ss).replace('GPT4 Correct Assistant: ', '').lstrip()
   
   return ss
   
   

In [8]:
ask("What is your name?")

'<s> GPT4 Correct User: What is your name? My name is GPT-4.</s>'

In [9]:
ask("use C# code write HELLO string to console")

'<s> GPT4 Correct User: use C# code write HELLO string to console Sure, here\'s an example C# code that writes "HELLO" to the console:\n\n```csharp\nusing System;\n\nclass Program\n{\n    static void Main(string[] args)\n    {\n        Console.WriteLine("HELLO");\n    }\n}\n```</s>'

In [18]:
ask("What is your name?")

'<s> GPT4 Correct User: What is your name? My name is GPT-4.</s>'