In [1]:
from queue import Empty, Queue
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer

In [2]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
# 4-bit NF4 weights with nested (double) quantization; 4-bit compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type='nf4',
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)

device = "cuda:0"

# Mistral is a built-in transformers architecture, so trust_remote_code is not
# needed; enabling it would permit arbitrary Python from the Hub repo to execute.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    # NOTE(review): non-quantized modules load in float16 while the 4-bit compute
    # dtype above is bfloat16 — confirm this mix is intended.
    torch_dtype=torch.float16,
    device_map=device,
)

[2024-01-30 11:15:33,063] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Upper bound on newly generated tokens per reply in the generate() helper.
max_new_tokens: int = 512

In [4]:
def remove_inst_tags(input_string):
    """Strip the prompt portion of a decoded Mistral-Instruct transcript.

    Removes everything from the first ``[INST]`` tag through the end of the last
    ``[/INST]`` tag, so only the text after the final instruction remains
    (leading whitespace of that text is kept).

    Args:
        input_string: Decoded model output, e.g. from
            ``tokenizer.decode(..., skip_special_tokens=True)``.

    Returns:
        The string with that span removed. If there is no ``[INST]`` tag, or no
        ``[/INST]`` tag after it, the remaining text is returned unchanged.
    """
    open_tag, close_tag = '[INST]', '[/INST]'
    while True:
        start_index = input_string.find(open_tag)
        close_index = input_string.rfind(close_tag)
        # Only splice when a closing tag follows the opening one. Otherwise
        # rfind() == -1 (or an index before start) would splice in the wrong
        # slice, garbling the text or growing it so the loop never terminates.
        if start_index == -1 or close_index < start_index:
            return input_string
        end_index = close_index + len(close_tag)
        input_string = input_string[:start_index] + input_string[end_index:]

In [5]:

class CustomStreamer(TextStreamer):
    """TextStreamer that pushes decoded text chunks onto a queue instead of printing.

    A consumer reads chunks with ``queue.get()`` until it receives
    ``stop_signal`` (``None``), which is enqueued right after the final chunk.

    Args:
        queue: A ``queue.Queue``-like object; only its ``put`` method is used.
        tokenizer: Tokenizer passed through to ``TextStreamer``.
        skip_prompt: If True, prompt tokens are not streamed. Defaults to False,
            matching ``TextStreamer``.
        **decode_kwargs: Forwarded to ``TextStreamer`` for the tokenizer's decode
            call (e.g. ``skip_special_tokens=True`` keeps ``</s>`` out of the stream).
    """

    def __init__(self, queue, tokenizer, skip_prompt=False, **decode_kwargs) -> None:
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self._queue = queue
        # Sentinel put on the queue after the last chunk of a stream.
        self.stop_signal = None
        # Not read by this class; available to consumers, e.g. as a
        # queue.get() timeout in seconds.
        self.timeout = 1

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Enqueue ``text``; when ``stream_end`` is True also enqueue ``stop_signal``."""
        self._queue.put(text)
        if stream_end:
            self._queue.put(self.stop_signal)
            

In [6]:
text = "Hi, my name is Jaswant"

messages = [{
    "role": "user",
    "content": text
}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)

streamer_queue = Queue()
# skip_special_tokens keeps "</s>" out of the streamed text; without it the
# streamer decodes special tokens along with the reply.
streamer = CustomStreamer(streamer_queue, tokenizer, skip_prompt=True, skip_special_tokens=True)
# Short demo generation: at most 64 new tokens, sampled at low temperature.
generation_kwargs = dict(inputs=inputs, streamer=streamer,
                         pad_token_id=tokenizer.eos_token_id,
                         max_new_tokens=64, temperature=0.1,
                         do_sample=True)

In [7]:
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

while True:
    try:
        value = streamer_queue.get(timeout=streamer.timeout)
    except Empty:
        # The stop signal is only enqueued when the stream ends normally. If
        # generate() raised inside the thread, stop once it has exited and the
        # queue is drained instead of blocking forever.
        if thread.is_alive() or not streamer_queue.empty():
            continue
        break
    if value is None:
        break
    print(value, end="")
    streamer_queue.task_done()

# Wait for generate() to return so the worker thread does not outlive this cell.
thread.join()



Hello Jaswant, nice to meet you! How can I help you today? If you have any specific questions or topics you'd like to discuss, feel free to ask. I'm here to provide information and answer any questions you might have to the best of my ability. Let me know if there'

In [8]:
# Same generation without streaming. Nothing drains streamer_queue here, so the
# queue-backed streamer is left out; passing it would leave stale chunks and a
# stop signal in the queue for the next consumer to misread.
non_streaming_kwargs = {k: v for k, v in generation_kwargs.items() if k != "streamer"}
outputs = model.generate(**non_streaming_kwargs)

decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Drop the "[INST] ... [/INST]" prompt span, leaving only the model's reply.
output = remove_inst_tags(decoded_output)
output

" Hello Jaswant, nice to meet you! How can I help you today? If you have any specific questions or topics you'd like to discuss, feel free to ask. I'm here to provide information and answer any questions you might have to the best of my ability. Let me know if there'"

In [11]:
streamer_queue = Queue()
# skip_special_tokens keeps "</s>" out of the streamed text.
streamer = CustomStreamer(streamer_queue, tokenizer, skip_prompt=True, skip_special_tokens=True)

def generate(messages):
    """Generate a reply to a chat history, streaming chunks through ``streamer``.

    Meant to run on a worker thread while the caller drains ``streamer_queue``.
    After generation finishes, the full reply is printed again and returned.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts, formatted with ``tokenizer.apply_chat_template``.

    Returns:
        The decoded reply with the ``[INST] ... [/INST]`` prompt span removed.
    """
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)

    outputs = model.generate(inputs=inputs, streamer=streamer,
                             pad_token_id=tokenizer.eos_token_id,
                             max_new_tokens=max_new_tokens,
                             temperature=0.1, do_sample=True)

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output = remove_inst_tags(decoded_output)
    print('\n----------Displaying Output-------------\n')
    print(output)
    return output

In [12]:
text = "How to print Hello World in python ?"

messages = [{
    "role": "user",
    "content": text
}]

thread = Thread(target=generate, kwargs=dict(messages=messages))
thread.start()

while True:
    try:
        value = streamer_queue.get(timeout=streamer.timeout)
    except Empty:
        # No stop signal arrives if generate() raised in the thread; stop once
        # it has exited and the queue is drained instead of hanging the cell.
        if thread.is_alive() or not streamer_queue.empty():
            continue
        break
    if value is None:
        break
    print(value, end="")
    streamer_queue.task_done()

# generate() prints the final reply after the stop signal; join so that output
# is produced while this cell is still running.
thread.join()

To print the text "Hello World" in Python, you can use the `print()` function. Here's an example of how to use it:

```python
# This is a comment - it will not be executed

# Use the print function to print "Hello World"
print("Hello World")
```

When you run this code, the output will be:

```
Hello World
```

You can also print multiple lines by separating them with commas:

```python
print("Hello World", "This is a second line")
```

The output will be:

```
Hello World
This is a second line
```</s>
----------Displaying Output-------------

 To print the text "Hello World" in Python, you can use the `print()` function. Here's an example of how to use it:

```python
# This is a comment - it will not be executed

# Use the print function to print "Hello World"
print("Hello World")
```

When you run this code, the output will be:

```
Hello World
```

You can also print multiple lines by separating them with commas:

```python
print("Hello World", "This is a second line")
```

The outp