In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
from pathlib import Path
import re

class ChatGenerator:
    """Answer user questions with a locally stored Llama 3.2 3B Instruct model.

    The tokenizer and model are loaded once in ``__init__`` and wrapped in a
    ``transformers`` text-generation pipeline that ``generate_response``
    reuses for every question.
    """

    # Header that opens an assistant turn in the Llama 3 chat markup.
    _ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"

    def __init__(self, checkpoint=None, device_map="cuda"):
        """Load the tokenizer, the bfloat16 model and the generation pipeline.

        Args:
            checkpoint: Model location passed to ``from_pretrained``. Defaults
                to ``Path.cwd().parent / "Llama-3.2-3B-Instruct"``.
            device_map: Device placement passed to ``from_pretrained`` and the
                pipeline. The default ``"cuda"`` targets the CUDA GPU rather
                than auto-detecting one; pass e.g. ``"auto"`` to let
                transformers decide the placement.
        """
        if checkpoint is None:
            checkpoint = Path.cwd().parent / "Llama-3.2-3B-Instruct"
        self.checkpoint = checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.checkpoint,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
        # Wrap the already-loaded model and tokenizer in a generation pipeline.
        self.text_generator = transformers.pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )

    @classmethod
    def _build_prompt(cls, question: str) -> str:
        """Wrap *question* as a single user turn in the Llama 3 chat format.

        Each ``<|end_header_id|>`` is followed by a blank line, and the prompt
        ends on an open assistant header so the model writes the reply next.
        """
        return (
            "<|begin_of_text|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{question}"
            "<|eot_id|>"
            f"{cls._ASSISTANT_HEADER}\n\n"
        )

    @classmethod
    def _extract_response(cls, prompt: str, generated_text: str) -> str:
        """Return the model's reply contained in *generated_text*.

        The pipeline's ``generated_text`` is expected to be the prompt followed
        by the newly generated text, so the reply is whatever follows the
        prompt. If the text does not start with the prompt, fall back to the
        text after the last assistant header, or to the whole text when there
        is no such header. Surrounding whitespace is stripped.
        """
        if generated_text.startswith(prompt):
            reply = generated_text[len(prompt):]
        else:
            # rpartition returns ('', '', text) when the header is absent.
            _, _, reply = generated_text.rpartition(cls._ASSISTANT_HEADER)
        return reply.strip()

    def generate_response(self, question: str):
        """Generate a reply to *question*.

        Args:
            question: User text, sent to the model as a single user turn.

        Returns:
            A ``(sequences, response)`` tuple. ``sequences`` is the raw pipeline
            output (dicts with a ``"generated_text"`` key). ``response`` is the
            reply text with the prompt removed, or ``""`` if the pipeline
            returned no sequences.
        """
        prompt = self._build_prompt(question)
        sequences = self.text_generator(
            prompt,
            do_sample=True,
            top_k=2,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            truncation=True,
            max_length=6000,  # counts prompt tokens plus generated tokens
        )

        # Read the text field directly instead of regex-parsing the dict's
        # repr. The repr switches quote style when the reply contains "'",
        # which broke the old pattern. As before, the last sequence wins.
        response = ""
        if sequences:
            response = self._extract_response(prompt, sequences[-1]["generated_text"])
        return sequences, response


chat_bot = ChatGenerator()  # Load the model once; reused for every question.
# Plain-text transcript of earlier turns. It is prepended to each new question
# so that earlier turns are included in the prompt.
chat_history = ""


while True:
    try:
        question = input("Enter your question: ")
    except EOFError:  # closed stdin / Ctrl-D ends the session like "exit"
        break
    if question.strip().lower() == "exit":
        break
    sequences, response = chat_bot.generate_response(chat_history + question)
    # Record only this turn's text. Appending str(sequences) (the raw pipeline
    # output) would re-add the full prompt, which already contains the
    # history, so the history would double on every turn.
    chat_history += f"User: {question}\nAssistant: {response}\n"
    print(response)