In [1]:
!pip install farm-haystack transformers ollama
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install sentence-transformers
!pip install "farm-haystack[inference]"

In [118]:
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document
import ollama

In [119]:
# In-memory store for the key-term documents. The embedding dimension (384) has to agree
# with the vector size produced by the embedding model configured on the retriever.
document_store = InMemoryDocumentStore(embedding_dim=384)

# Read one key term per line. The context manager closes the file deterministically, and
# blank lines (e.g. the one produced by a trailing newline) are dropped so that no
# empty-content Document is indexed. Surrounding whitespace is stripped so the stored
# terms compare equal to the stripped terms parsed from the model's answer later on.
with open('terms.csv') as terms_file:
    stripped_lines = (line.strip() for line in terms_file)
    key_terms = [term for term in stripped_lines if term]

# One Document per key term; the term text itself is the document content.
documents = [Document(content=term) for term in key_terms]

document_store.write_documents(documents)

In [120]:
# Dense retriever backed by a sentence-transformers model, running on CPU only.
retriever = EmbeddingRetriever(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    use_gpu=False,
    document_store=document_store,
)

# Compute and store an embedding for every document currently in the store.
document_store.update_embeddings(retriever)

Batches: 100%|██████████| 224/224 [00:02<00:00, 80.89it/s]cs/s]
Documents Processed: 10000 docs [00:02, 3526.80 docs/s]         


In [144]:
# Module-level Ollama client, constructed with default settings (no arguments passed).
# ask_llama sends its chat requests through this object.
client = ollama.Client()

def ask_llama(question, relevant_key_terms, model="llama3.2", chat_client=None):
    """Ask the chat model to pick the three most relevant key terms for a question.

    Args:
        question: Free-text question the key terms should relate to.
        relevant_key_terms: Candidate key terms; interpolated into the prompt as-is
            (a list is rendered with its Python repr).
        model: Name of the model passed to the chat call. Defaults to "llama3.2".
        chat_client: Object exposing ``chat(model=..., messages=...)``. When omitted,
            the module-level ``client`` is used.

    Returns:
        The ``['message']['content']`` field of the chat response. The prompt requests
        a comma-separated list of terms, but the reply is returned unvalidated.
    """
    # The prompt text (including its indentation) is sent verbatim to the model.
    prompt = f"""
        You are selecting the three most relevant key terms from the UN database relevant to the following question:
        Question: {question}
        Key Terms: {relevant_key_terms}
        Output only the selected key terms, separated by commas, nothing else.
        """

    # Fall back to the shared module-level client only when none was injected.
    active_client = client if chat_client is None else chat_client
    response = active_client.chat(model=model, messages=[{"role": "user", "content": prompt}])
    return response['message']['content']


In [145]:
def get_relevant_key_terms(question, top_k=20):
    """Return the content of the documents the retriever ranks highest for a question.

    Args:
        question: Query text handed to the retriever.
        top_k: Number of results requested from the retriever (default 20).

    Returns:
        A list with the ``content`` of each retrieved document, in retrieval order.
    """
    hits = retriever.retrieve(query=question, top_k=top_k)
    terms = []
    for hit in hits:
        terms.append(hit.content)
    return terms

# Example query; the validation cell further down reuses `question` as well.
question = "war in europe"
# Candidate key terms for the query, retrieved with the function's default top_k.
relevant_key_terms = get_relevant_key_terms(question)
print("Relevant Key Terms:", relevant_key_terms)

Batches: 100%|██████████| 1/1 [00:00<00:00, 138.89it/s]


Relevant Key Terms: ['WAR', 'WORLD WAR (1914-1918)', 'WAR CRIMES', 'ARMED CONFLICTS', 'EUROPE', 'WORLD WAR (1939-1945)', 'WAR PROPAGANDA', 'NUCLEAR WAR', 'WAR PREVENTION', 'WESTERN EUROPE', 'CIVIL WAR', 'WAR CRIMINALS', 'PRISONERS OF WAR', 'EUROPEAN STUDIES', 'ETHNIC CONFLICT', 'EASTERN EUROPE', 'ARMIES', 'SOUTHERN EUROPE', 'MILITARY OCCUPATION', 'WAR VICTIMS']


In [148]:
# Upper bound on model calls so an answer that never validates cannot loop forever.
MAX_ATTEMPTS = 5
# Number of key terms the prompt asks the model to return.
EXPECTED_TERM_COUNT = 3

is_answer_valid = False
answer_split = []
for _attempt in range(MAX_ATTEMPTS):
    answer = ask_llama(question, relevant_key_terms)
    print("Answer:", answer)

    # Guard the type before splitting; a non-string reply is simply treated as invalid.
    answer_split = [term.strip() for term in answer.split(',')] if isinstance(answer, str) else []

    # Valid means: exactly the expected number of terms, no repeats, and every term is
    # one of the retrieved candidates (exact match after stripping whitespace).
    is_answer_valid = (
        len(answer_split) == EXPECTED_TERM_COUNT
        and len(set(answer_split)) == EXPECTED_TERM_COUNT
        and all(term in relevant_key_terms for term in answer_split)
    )
    print(is_answer_valid)
    if is_answer_valid:
        break
else:
    # Only reached when every attempt produced an invalid answer.
    raise RuntimeError(
        f"No valid selection of {EXPECTED_TERM_COUNT} key terms after {MAX_ATTEMPTS} attempts"
    )

Answer: WORLD WAR (1914-1918), WORLD WAR (1939-1945), ARMED CONFLICTS
True


In [149]:
print(answer_split)
# Build one "subjectheading:[TERM]" filter per selected term and concatenate them with no
# separator. Joining over the list avoids hard-coding the indices 0..2, so the cell keeps
# working if the number of selected terms changes.
output = "".join(f"subjectheading:[{term}]" for term in answer_split)
print(output)

['WORLD WAR (1914-1918)', 'WORLD WAR (1939-1945)', 'ARMED CONFLICTS']
subjectheading:[WORLD WAR (1914-1918)]subjectheading:[WORLD WAR (1939-1945)]subjectheading:[ARMED CONFLICTS]
