In [1]:
import os
from dotenv import dotenv_values
import torch

from src.QA_Chain import QA_Chain

In [2]:
# Fraction of the GPU's total memory this process is allowed to allocate.
GPU_MEMORY_FRACTION = 0.8

# Cap this process's CUDA memory usage (only when a GPU is present).
if torch.cuda.is_available():
    print(
        f"GPU found lessgoo..., setting memory fraction to {GPU_MEMORY_FRACTION:.0%}")
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION)

# Key/value pairs read from the local .env file (the process environment is
# not consulted by dotenv_values).
config = dotenv_values(".env")


def _config_path(key: str, default: str) -> str:
    """Return the absolute path configured under ``key`` in ``.env``.

    Falls back to ``default`` when the key is missing, is declared without a
    value (``dotenv_values`` yields ``None`` for those) or is empty. Without
    the fallback a ``None`` would make ``os.path.abspath`` raise ``TypeError``
    and an empty string would silently resolve to the current directory.
    """
    return os.path.abspath(config.get(key) or default)


# Hugging Face model identifiers. Alternatives that were tried:
# MODEL_NAME = 'meta-llama/Llama-3.2-1B'
# MODEL_NAME = 'distilgpt2'
MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'
EMBEDDING_NAME = 'sentence-transformers/paraphrase-MiniLM-L6-v2'

# Local directories for model and index artefacts; each can be overridden
# in .env.
EMBEDDING_MODEL_PATH = _config_path(
    'EMBEDDING_MODEL_PATH', './models/embedding_model')
VECTOR_STORE_PATH = _config_path('VECTOR_STORE_PATH', './models/vector_store')
LLM_MODEL_PATH = _config_path('LLM_MODEL_PATH', './models/model')

print("Embedding model path:", EMBEDDING_MODEL_PATH)
print("Vector store path:", VECTOR_STORE_PATH)
print("LLM Model path:", LLM_MODEL_PATH)

# NOTE: the .env key is named DOCUMENT_DIR_PATH but it holds the *parent*
# directory; the document set actually used is the DOCUMENT_DIR_NAME
# sub-directory below it.
DOCUMENT_PARENT_DIR_PATH = _config_path('DOCUMENT_DIR_PATH', './documents')
DOCUMENT_DIR_NAME = 'test'
DOCUMENT_DIR_PATH = os.path.join(DOCUMENT_PARENT_DIR_PATH, DOCUMENT_DIR_NAME)

print("Loading documents from directory:", DOCUMENT_DIR_PATH)

GPU found lessgoo..., setting memory fraction to 80%
Embedding model path: c:\Users\bryan\Documents\GitHub\NTU-FYP-Chatbot-AI\models\embedding_model
Vector store path: c:\Users\bryan\Documents\GitHub\NTU-FYP-Chatbot-AI\models\vector_store
LLM Model path: c:\Users\bryan\Documents\GitHub\NTU-FYP-Chatbot-AI\models\llm_model
Loading documents from directory: c:\Users\bryan\Documents\GitHub\NTU-FYP-Chatbot-AI\documents\test


In [3]:
# When this cell is re-run, dispose of the previous chain before building a
# new one (presumably destroy() releases the loaded models - confirm in
# src/QA_Chain).
try:
    qa_chain.destroy()
    print("Someone is already running the QA Chain, destroying the old instance.")
except NameError:
    # Fresh kernel: no earlier `qa_chain` exists, nothing to clean up.
    pass
except Exception as exc:
    # Clean-up stays best effort, but the failure is reported instead of
    # being silently swallowed; KeyboardInterrupt is no longer caught.
    print(f"Could not destroy the previous QA Chain instance: {exc!r}")

qa_chain = QA_Chain()

qa_chain.load_embeddings_model(
    EMBEDDING_NAME, embedding_model_path=EMBEDDING_MODEL_PATH)

# One vector store per (document set, embedding model) combination.
vector_store_path = os.path.join(
    VECTOR_STORE_PATH, f"{DOCUMENT_DIR_NAME}_{qa_chain.embeddings.model_name}")

# os.listdir returns entries in arbitrary order; sort them so the list of
# documents handed to the vector store is the same on every run.
file_paths = [os.path.join(DOCUMENT_DIR_PATH, filepath)
              for filepath in sorted(os.listdir(DOCUMENT_DIR_PATH))]
file_paths_abs = [os.path.abspath(file_path) for file_path in file_paths]

qa_chain.load_vector_store(vector_store_path, file_paths_abs)

✅ Loaded vector store from local storage.


In [4]:
# IMPORTANT: make sure you are authenticated and have been granted access
# before running this cell (presumably to the Hugging Face Hub, since
# MODEL_NAME points at a meta-llama checkpoint - confirm).
# NOTE(review): this import would normally live in the notebook's import cell.
from langchain_core.prompts import ChatPromptTemplate

# Prompt text, one entry per line. The leading/trailing empty entries and the
# trailing space after "information." are kept so the rendered prompt stays
# exactly the same as before.
# NOTE(review): "{input}?" appends a question mark even when the query
# already ends with one.
_PROMPT_LINES = (
    "",
    "Persona:",
    "You are an AI model that provides short, concise answers.",
    "If you do not know the answer, respond with \"I don't know.\"",
    "Do not make up information. ",
    "",
    "Only generate one answer, do not generate questions.",
    "",
    "Context:",
    "{context}",
    "",
    "Question:",
    "{input}?",
    "",
    "Answer:",
    "",
)
custom_prompt_template = ChatPromptTemplate.from_template(
    "\n".join(_PROMPT_LINES))

# To experiment without the custom prompt, use this instead (presumably
# QA_Chain then applies its own default - verify in src/QA_Chain):
# custom_prompt_template = None

# Generation settings for the language model.
_llm_settings = {
    "model_name": MODEL_NAME,
    "model_path": LLM_MODEL_PATH,
    "max_new_tokens": 512,
    "temperature": 0.5,
}
qa_chain.initialize_llm(**_llm_settings)

# Known issue noted by the author: with top_k >= 4 the chain does not work
# (cause not investigated). No top_k is passed in this call.
qa_chain.initialize_qa_chain(prompt_template=custom_prompt_template)

🔄 Loading model from c:\Users\bryan\Documents\GitHub\NTU-FYP-Chatbot-AI\models\llm_model\meta-llama/Llama-3.2-1B-Instruct...


Some parameters are on the meta device because they were offloaded to the cpu.


In [5]:
# Natural-language question to put to the QA chain.
query = "is this a pass fail course?"

# Run the question through the chain. The returned object is indexed with
# the 'answer' key when the result is printed in the following cell.
result = qa_chain.query(query)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [6]:
# Show the question, a separator header and the generated answer, one per line.
print(query, "===== Answer =====", result['answer'], sep="\n")

is this a pass fail course?
===== Answer =====
No, it is not a Pass/Fail course. You will be graded based on the components mentioned below. It is mandatory that you "attempt" at least 80% of the graded components to Pass this course. This means you can't simply drop the Mini-Project (30%) or the AI Theory Quiz (25%). You may miss a few components, if you have to, but the total weight of the components you miss should not be more than 20%. Be careful, and choose wisely, in case you do need to miss out on any component.
